From 452210f88ac1180710f93ebecbb6394e38b970fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Wallez?= Date: Tue, 7 May 2024 16:35:05 +0200 Subject: [PATCH] feat: add typeclasses for easier use of State.Typed (#19) * feat: add typeclasses for easier use of State.Typed * cleanup: improve type and function names in DY.Lib.State.Tagged * cleanup: improve `map_types` instance names * cleanup: use a dedicated type for PrivateKey's map keys --- .../nsl_pk/DY.Example.NSL.Debug.Printing.fst | 2 +- ...DY.Example.NSL.Protocol.Stateful.Proof.fst | 26 +-- .../DY.Example.NSL.Protocol.Stateful.fst | 24 +-- .../DY.Example.NSL.SecurityProperties.fst | 12 +- src/lib/state/DY.Lib.State.Map.fst | 177 ++++++++++-------- src/lib/state/DY.Lib.State.PKI.fst | 34 ++-- src/lib/state/DY.Lib.State.PrivateKeys.fst | 44 +++-- src/lib/state/DY.Lib.State.Tagged.fst | 132 ++++++------- src/lib/state/DY.Lib.State.Typed.fst | 120 ++++++------ src/lib/utils/DY.Lib.Printing.fst | 20 +- 10 files changed, 321 insertions(+), 270 deletions(-) diff --git a/examples/nsl_pk/DY.Example.NSL.Debug.Printing.fst b/examples/nsl_pk/DY.Example.NSL.Debug.Printing.fst index b3616c9..79f50b7 100644 --- a/examples/nsl_pk/DY.Example.NSL.Debug.Printing.fst +++ b/examples/nsl_pk/DY.Example.NSL.Debug.Printing.fst @@ -81,5 +81,5 @@ val get_nsl_trace_to_string_printers: bytes -> bytes -> trace_to_string_printers let get_nsl_trace_to_string_printers priv_key_alice priv_key_bob = trace_to_string_printers_builder (message_to_string priv_key_alice priv_key_bob) - [(nsl_session_tag, session_to_string)] + [(local_state_nsl_session.tag, session_to_string)] [(event_nsl_event.tag, event_to_string)] diff --git a/examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.Proof.fst b/examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.Proof.fst index 51436f6..14b190f 100644 --- a/examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.Proof.fst +++ b/examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.Proof.fst @@ -16,7 +16,7 @@ open DY.Example.NSL.Protocol.Stateful /// The (local) state predicate. -let nsl_session_pred: typed_session_pred nsl_session = { +let nsl_session_pred: local_state_predicate nsl_session = { pred = (fun tr prin sess_id st -> match st with | InitiatorSentMsg1 bob n_a -> ( @@ -82,9 +82,9 @@ let nsl_event_pred: event_predicate nsl_event = /// List of all local state predicates. let all_sessions = [ - (pki_tag, typed_session_pred_to_session_pred (map_session_invariant pki_pred)); - (private_keys_tag, typed_session_pred_to_session_pred (map_session_invariant private_keys_pred)); - (nsl_session_tag, typed_session_pred_to_session_pred nsl_session_pred); + pki_tag_and_invariant; + private_keys_tag_and_invariant; + (local_state_nsl_session.tag, local_state_predicate_to_local_bytes_state_predicate nsl_session_pred); ] /// List of all local event predicates. @@ -107,11 +107,11 @@ instance nsl_protocol_invs: protocol_invariants = { /// Lemmas that the global state predicate contains all the local ones -val all_sessions_has_all_sessions: unit -> Lemma (norm [delta_only [`%all_sessions; `%for_allP]; iota; zeta] (for_allP (has_session_pred nsl_protocol_invs) all_sessions)) +val all_sessions_has_all_sessions: unit -> Lemma (norm [delta_only [`%all_sessions; `%for_allP]; iota; zeta] (for_allP (has_local_bytes_state_predicate nsl_protocol_invs) all_sessions)) let all_sessions_has_all_sessions () = assert_norm(List.Tot.no_repeats_p (List.Tot.map fst (all_sessions))); - mk_global_session_pred_correct nsl_protocol_invs all_sessions; - norm_spec [delta_only [`%all_sessions; `%for_allP]; iota; zeta] (for_allP (has_session_pred nsl_protocol_invs) all_sessions) + mk_global_local_bytes_state_predicate_correct nsl_protocol_invs all_sessions; + norm_spec [delta_only [`%all_sessions; `%for_allP]; iota; zeta] (for_allP (has_local_bytes_state_predicate nsl_protocol_invs) all_sessions) val full_nsl_session_pred_has_pki_invariant: squash (has_pki_invariant nsl_protocol_invs) let full_nsl_session_pred_has_pki_invariant = all_sessions_has_all_sessions () @@ -119,7 +119,7 @@ let full_nsl_session_pred_has_pki_invariant = all_sessions_has_all_sessions () val full_nsl_session_pred_has_private_keys_invariant: squash (has_private_keys_invariant nsl_protocol_invs) let full_nsl_session_pred_has_private_keys_invariant = all_sessions_has_all_sessions () -val full_nsl_session_pred_has_nsl_invariant: squash (has_typed_session_pred nsl_protocol_invs (nsl_session_tag, nsl_session_pred)) +val full_nsl_session_pred_has_nsl_invariant: squash (has_local_state_predicate nsl_protocol_invs nsl_session_pred) let full_nsl_session_pred_has_nsl_invariant = all_sessions_has_all_sessions () /// Lemmas that the global event predicate contains all the local ones @@ -159,7 +159,7 @@ val send_msg1_proof: trace_invariant tr_out )) let send_msg1_proof tr global_sess_id alice sess_id = - match get_typed_state #nsl_session nsl_session_tag alice sess_id tr with + match get_state alice sess_id tr with | (Some (InitiatorSentMsg1 bob n_a), tr) -> ( match get_public_key alice global_sess_id.pki (PkEnc "NSL.PublicKey") bob tr with | (None, tr) -> () @@ -200,7 +200,7 @@ val send_msg2_proof: trace_invariant tr_out )) let send_msg2_proof tr global_sess_id bob sess_id = - match get_typed_state nsl_session_tag bob sess_id tr with + match get_state bob sess_id tr with | (Some (ResponderSentMsg2 alice n_a n_b), tr) -> ( match get_public_key bob global_sess_id.pki (PkEnc "NSL.PublicKey") alice tr with | (None, tr) -> () @@ -227,7 +227,7 @@ let prepare_msg3_proof tr global_sess_id alice sess_id msg_id = match get_private_key alice global_sess_id.private_keys (PkDec "NSL.PublicKey") tr with | (None, tr) -> () | (Some sk_a, tr) -> ( - match get_typed_state nsl_session_tag alice sess_id tr with + match get_state alice sess_id tr with | (Some (InitiatorSentMsg1 bob n_a), tr) -> ( decode_message2_proof tr alice bob msg sk_a n_a ) @@ -245,7 +245,7 @@ val send_msg3_proof: trace_invariant tr_out )) let send_msg3_proof tr global_sess_id alice sess_id = - match get_typed_state nsl_session_tag alice sess_id tr with + match get_state alice sess_id tr with | (Some (InitiatorSentMsg3 bob n_a n_b), tr) -> ( match get_public_key alice global_sess_id.pki (PkEnc "NSL.PublicKey") bob tr with | (None, tr) -> () @@ -289,7 +289,7 @@ let prepare_msg4 tr global_sess_id bob sess_id msg_id = match get_private_key bob global_sess_id.private_keys (PkDec "NSL.PublicKey") tr with | (None, tr) -> () | (Some sk_b, tr) -> ( - match get_typed_state nsl_session_tag bob sess_id tr with + match get_state bob sess_id tr with | (Some (ResponderSentMsg2 alice n_a n_b), tr) -> ( decode_message3_proof tr alice bob msg sk_b n_b; diff --git a/examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.fst b/examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.fst index bee834d..463caff 100644 --- a/examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.fst +++ b/examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.fst @@ -26,8 +26,10 @@ type nsl_session = instance nsl_session_parseable_serializeable: parseable_serializeable bytes nsl_session = mk_parseable_serializeable ps_nsl_session -val nsl_session_tag: string -let nsl_session_tag = "NSL.Session" +instance local_state_nsl_session: local_state nsl_session = { + tag = "NSL.Session"; + format = nsl_session_parseable_serializeable; +} (*** Event type ***) @@ -61,12 +63,12 @@ let prepare_msg1 alice bob = let* n_a = mk_rand NoUsage (join (principal_label alice) (principal_label bob)) 32 in trigger_event alice (Initiate1 alice bob n_a);* let* sess_id = new_session_id alice in - set_typed_state nsl_session_tag alice sess_id (InitiatorSentMsg1 bob n_a <: nsl_session);* + set_state alice sess_id (InitiatorSentMsg1 bob n_a <: nsl_session);* return sess_id val send_msg1: nsl_global_sess_ids -> principal -> nat -> crypto (option nat) let send_msg1 global_sess_id alice sess_id = - let*? st: nsl_session = get_typed_state nsl_session_tag alice sess_id in + let*? st: nsl_session = get_state alice sess_id in match st with | InitiatorSentMsg1 bob n_a -> ( let*? pk_b = get_public_key alice global_sess_id.pki (PkEnc "NSL.PublicKey") bob in @@ -85,12 +87,12 @@ let prepare_msg2 global_sess_id bob msg_id = let* n_b = mk_rand NoUsage (join (principal_label msg1.alice) (principal_label bob)) 32 in trigger_event bob (Respond1 msg1.alice bob msg1.n_a n_b);* let* sess_id = new_session_id bob in - set_typed_state nsl_session_tag bob sess_id (ResponderSentMsg2 msg1.alice msg1.n_a n_b <: nsl_session);* + set_state bob sess_id (ResponderSentMsg2 msg1.alice msg1.n_a n_b <: nsl_session);* return (Some sess_id) val send_msg2: nsl_global_sess_ids -> principal -> nat -> crypto (option nat) let send_msg2 global_sess_id bob sess_id = - let*? st: nsl_session = get_typed_state nsl_session_tag bob sess_id in + let*? st: nsl_session = get_state bob sess_id in match st with | ResponderSentMsg2 alice n_a n_b -> ( let*? pk_a = get_public_key bob global_sess_id.pki (PkEnc "NSL.PublicKey") alice in @@ -105,19 +107,19 @@ val prepare_msg3: nsl_global_sess_ids -> principal -> nat -> nat -> crypto (opti let prepare_msg3 global_sess_id alice sess_id msg_id = let*? msg = recv_msg msg_id in let*? sk_a = get_private_key alice global_sess_id.private_keys (PkDec "NSL.PublicKey") in - let*? st: nsl_session = get_typed_state nsl_session_tag alice sess_id in + let*? st: nsl_session = get_state alice sess_id in match st with | InitiatorSentMsg1 bob n_a -> ( let*? msg2: message2 = return (decode_message2 alice bob msg sk_a n_a) in trigger_event alice (Initiate2 alice bob n_a msg2.n_b);* - set_typed_state nsl_session_tag alice sess_id (InitiatorSentMsg3 bob n_a msg2.n_b <: nsl_session);* + set_state alice sess_id (InitiatorSentMsg3 bob n_a msg2.n_b <: nsl_session);* return (Some ()) ) | _ -> return None val send_msg3: nsl_global_sess_ids -> principal -> nat -> crypto (option nat) let send_msg3 global_sess_id alice sess_id = - let*? st: nsl_session = get_typed_state nsl_session_tag alice sess_id in + let*? st: nsl_session = get_state alice sess_id in match st with | InitiatorSentMsg3 bob n_a n_b -> ( let*? pk_b = get_public_key alice global_sess_id.pki (PkEnc "NSL.PublicKey") bob in @@ -132,12 +134,12 @@ val prepare_msg4: nsl_global_sess_ids -> principal -> nat -> nat -> crypto (opti let prepare_msg4 global_sess_id bob sess_id msg_id = let*? msg = recv_msg msg_id in let*? sk_b = get_private_key bob global_sess_id.private_keys (PkDec "NSL.PublicKey") in - let*? st: nsl_session = get_typed_state nsl_session_tag bob sess_id in + let*? st: nsl_session = get_state bob sess_id in match st with | ResponderSentMsg2 alice n_a n_b -> ( let*? msg3: message3 = return (decode_message3 alice bob msg sk_b n_b) in trigger_event bob (Respond2 alice bob n_a n_b);* - set_typed_state nsl_session_tag bob sess_id (ResponderReceivedMsg3 alice n_a n_b <: nsl_session);* + set_state bob sess_id (ResponderReceivedMsg3 alice n_a n_b <: nsl_session);* return (Some ()) ) | _ -> return None diff --git a/examples/nsl_pk/DY.Example.NSL.SecurityProperties.fst b/examples/nsl_pk/DY.Example.NSL.SecurityProperties.fst index cbe29a9..be0c8cf 100644 --- a/examples/nsl_pk/DY.Example.NSL.SecurityProperties.fst +++ b/examples/nsl_pk/DY.Example.NSL.SecurityProperties.fst @@ -62,9 +62,9 @@ val n_a_secrecy: (requires attacker_knows tr n_a /\ trace_invariant tr /\ ( - (exists sess_id. typed_state_was_set tr nsl_session_tag alice sess_id (InitiatorSentMsg1 bob n_a)) \/ - (exists sess_id n_b. typed_state_was_set tr nsl_session_tag alice sess_id (InitiatorSentMsg3 bob n_a n_b)) \/ - (exists sess_id n_b. typed_state_was_set tr nsl_session_tag bob sess_id (ResponderReceivedMsg3 alice n_a n_b)) + (exists sess_id. state_was_set tr alice sess_id (InitiatorSentMsg1 bob n_a)) \/ + (exists sess_id n_b. state_was_set tr alice sess_id (InitiatorSentMsg3 bob n_a n_b)) \/ + (exists sess_id n_b. state_was_set tr bob sess_id (ResponderReceivedMsg3 alice n_a n_b)) ) ) (ensures is_corrupt tr (principal_label alice) \/ is_corrupt tr (principal_label bob)) @@ -80,9 +80,9 @@ val n_b_secrecy: (requires attacker_knows tr n_b /\ trace_invariant tr /\ ( - (exists sess_id n_a. typed_state_was_set tr nsl_session_tag bob sess_id (ResponderSentMsg2 alice n_a n_b)) \/ - (exists sess_id n_a. typed_state_was_set tr nsl_session_tag bob sess_id (ResponderReceivedMsg3 alice n_a n_b)) \/ - (exists sess_id n_a. typed_state_was_set tr nsl_session_tag alice sess_id (InitiatorSentMsg3 bob n_a n_b)) + (exists sess_id n_a. state_was_set tr bob sess_id (ResponderSentMsg2 alice n_a n_b)) \/ + (exists sess_id n_a. state_was_set tr bob sess_id (ResponderReceivedMsg3 alice n_a n_b)) \/ + (exists sess_id n_a. state_was_set tr alice sess_id (InitiatorSentMsg3 bob n_a n_b)) ) ) (ensures is_corrupt tr (principal_label alice) \/ is_corrupt tr (principal_label bob)) diff --git a/src/lib/state/DY.Lib.State.Map.fst b/src/lib/state/DY.Lib.State.Map.fst index 9703f64..ba43ae5 100644 --- a/src/lib/state/DY.Lib.State.Map.fst +++ b/src/lib/state/DY.Lib.State.Map.fst @@ -15,75 +15,88 @@ open DY.Lib.State.Typed /// The parameters necessary to define the map functions. -noeq type map_types = { - key: eqtype; - ps_key: parser_serializer bytes key; - value: Type0; - ps_value: parser_serializer bytes value; +class map_types (key_t:eqtype) (value_t:Type0) = { + tag: string; + ps_key_t: parser_serializer bytes key_t; + ps_value_t: parser_serializer bytes value_t; } /// Type for the map predicate, which is used to define the state predicate. /// The map predicate relates a key and its associated value. -noeq type map_predicate {|crypto_invariants|} (mt:map_types) = { - pred: trace -> principal -> nat -> mt.key -> mt.value -> prop; - pred_later: tr1:trace -> tr2:trace -> prin:principal -> sess_id:nat -> key:mt.key -> value:mt.value -> Lemma +noeq type map_predicate {|crypto_invariants|} (key_t:eqtype) (value_t:Type0) {|mt:map_types key_t value_t|} = { + pred: trace -> principal -> nat -> key_t -> value_t -> prop; + pred_later: tr1:trace -> tr2:trace -> prin:principal -> sess_id:nat -> key:key_t -> value:value_t -> Lemma (requires pred tr1 prin sess_id key value /\ tr1 <$ tr2) (ensures pred tr2 prin sess_id key value) ; - pred_knowable: tr:trace -> prin:principal -> sess_id:nat -> key:mt.key -> value:mt.value -> Lemma + pred_knowable: tr:trace -> prin:principal -> sess_id:nat -> key:key_t -> value:value_t -> Lemma (requires pred tr prin sess_id key value) - (ensures is_well_formed_prefix mt.ps_key (is_knowable_by (principal_state_label prin sess_id) tr) key /\ is_well_formed_prefix mt.ps_value (is_knowable_by (principal_state_label prin sess_id) tr) value) + (ensures is_well_formed_prefix mt.ps_key_t (is_knowable_by (principal_state_label prin sess_id) tr) key /\ is_well_formed_prefix mt.ps_value_t (is_knowable_by (principal_state_label prin sess_id) tr) value) ; } [@@ with_bytes bytes] -noeq type map_elem (mt:map_types) = { - [@@@ with_parser #bytes mt.ps_key] - key: mt.key; - [@@@ with_parser #bytes mt.ps_value] - value: mt.value; +noeq type map_elem (key_t:eqtype) (value_t:Type0) {|mt:map_types key_t value_t|} = { + [@@@ with_parser #bytes mt.ps_key_t] + key: key_t; + [@@@ with_parser #bytes mt.ps_value_t] + value: value_t; } %splice [ps_map_elem] (gen_parser (`map_elem)) %splice [ps_map_elem_is_well_formed] (gen_is_well_formed_lemma (`map_elem)) [@@ with_bytes bytes] -noeq type map (mt:map_types) = { - [@@@ with_parser #bytes (ps_list (ps_map_elem mt))] - key_values: list (map_elem mt) +noeq type map (key_t:eqtype) (value_t:Type0) {|mt:map_types key_t value_t|} = { + [@@@ with_parser #bytes (ps_list (ps_map_elem key_t value_t))] + key_values: list (map_elem key_t value_t) } %splice [ps_map] (gen_parser (`map)) %splice [ps_map_is_well_formed] (gen_is_well_formed_lemma (`map)) -instance parseable_serializeable_map (mt:map_types) : parseable_serializeable bytes (map mt) = mk_parseable_serializeable (ps_map mt) +instance parseable_serializeable_map (key_t:eqtype) (value_t:Type0) {|map_types key_t value_t|} : parseable_serializeable bytes (map key_t value_t) = mk_parseable_serializeable (ps_map key_t value_t) -val map_elem_invariant: {|crypto_invariants|} -> #mt:map_types -> map_predicate mt -> trace -> principal -> nat -> map_elem mt -> prop -let map_elem_invariant #cinvs #mt mpred tr prin sess_id x = +instance map_has_local_state (key_t:eqtype) (value_t:Type0) {|mt:map_types key_t value_t|}: local_state (map key_t value_t) = { + tag = mt.tag; + format = (parseable_serializeable_map key_t value_t); +} + +val map_elem_invariant: + {|crypto_invariants|} -> + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + map_predicate key_t value_t -> + trace -> principal -> nat -> map_elem key_t value_t -> + prop +let map_elem_invariant #cinvs #key_t #value_t #mt mpred tr prin sess_id x = mpred.pred tr prin sess_id x.key x.value val map_invariant: - {|crypto_invariants|} -> #mt:map_types -> - map_predicate mt -> trace -> principal -> nat -> map mt -> + {|crypto_invariants|} -> + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + map_predicate key_t value_t -> + trace -> principal -> nat -> map key_t value_t -> prop -let map_invariant #cinvs #mt mpred tr prin sess_id st = +let map_invariant #cinvs #key_t #value_t #mt mpred tr prin sess_id st = for_allP (map_elem_invariant mpred tr prin sess_id) st.key_values val map_invariant_eq: - {|crypto_invariants|} -> #mt:map_types -> - mpred:map_predicate mt -> tr:trace -> prin:principal -> sess_id:nat -> st:map mt -> + {|crypto_invariants|} -> + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + mpred:map_predicate key_t value_t -> + tr:trace -> prin:principal -> sess_id:nat -> st:map key_t value_t -> Lemma (map_invariant mpred tr prin sess_id st <==> (forall x. List.Tot.memP x st.key_values ==> map_elem_invariant mpred tr prin sess_id x)) -let map_invariant_eq #cinvs #mt mpred tr prin sess_id st = +let map_invariant_eq #cinvs #key_t #value_t #mt mpred tr prin sess_id st = for_allP_eq (map_elem_invariant mpred tr prin sess_id) st.key_values val map_session_invariant: {|crypto_invariants|} -> - #mt:map_types -> - mpred:map_predicate mt -> - typed_session_pred (map mt) -let map_session_invariant #cinvs #mt mpred = { + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + mpred:map_predicate key_t value_t -> + local_state_predicate (map key_t value_t) +let map_session_invariant #cinvs #key_t #value_t #mt mpred = { pred = (fun tr prin sess_id content -> map_invariant mpred tr prin sess_id content); pred_later = (fun tr1 tr2 prin sess_id content -> map_invariant_eq mpred tr1 prin sess_id content; @@ -93,8 +106,8 @@ let map_session_invariant #cinvs #mt mpred = { pred_knowable = (fun tr prin sess_id content -> let pre = (is_knowable_by (principal_state_label prin sess_id) tr) in map_invariant_eq mpred tr prin sess_id content; - for_allP_eq (is_well_formed_prefix (ps_map_elem mt) pre) content.key_values; - introduce forall x. map_elem_invariant mpred tr prin sess_id x ==> is_well_formed_prefix (ps_map_elem mt) pre x + for_allP_eq (is_well_formed_prefix (ps_map_elem key_t value_t) pre) content.key_values; + introduce forall x. map_elem_invariant mpred tr prin sess_id x ==> is_well_formed_prefix (ps_map_elem key_t value_t) pre x with ( introduce _ ==> _ with _. ( mpred.pred_knowable tr prin sess_id x.key x.value @@ -103,43 +116,50 @@ let map_session_invariant #cinvs #mt mpred = { ); } -val has_map_session_invariant: #mt:map_types -> protocol_invariants -> (string & map_predicate mt) -> prop -let has_map_session_invariant #mt invs (tag, mpred) = - has_typed_session_pred invs (tag, (map_session_invariant mpred)) +val has_map_session_invariant: + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + protocol_invariants -> map_predicate key_t value_t -> prop +let has_map_session_invariant #key_t #value_t #mt invs mpred = + has_local_state_predicate invs (map_session_invariant mpred) (*** Map API ***) [@@ "opaque_to_smt"] val initialize_map: - mt:map_types -> tag:string -> prin:principal -> + key_t:eqtype -> value_t:Type0 -> + {|map_types key_t value_t|} -> + prin:principal -> crypto nat -let initialize_map mt tag prin = +let initialize_map key_t value_t #mt prin = let* sess_id = new_session_id prin in - let session: map mt = { key_values = [] } in - set_typed_state tag prin sess_id session;* + let session: map key_t value_t = { key_values = [] } in + set_state prin sess_id session;* return sess_id [@@ "opaque_to_smt"] val add_key_value: - mt:map_types -> tag:string -> + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> prin:principal -> sess_id:nat -> - key:mt.key -> value:mt.value -> + key:key_t -> value:value_t -> crypto (option unit) -let add_key_value mt tag prin sess_id key value = - let*? the_map = get_typed_state tag prin sess_id in +let add_key_value #key_t #value_t #mt prin sess_id key value = + let*? the_map = get_state prin sess_id in let new_elem = {key; value;} in - set_typed_state tag prin sess_id { key_values = new_elem::the_map.key_values };* + set_state prin sess_id { key_values = new_elem::the_map.key_values };* return (Some ()) #push-options "--fuel 1 --ifuel 1" -val find_value_aux: #mt:map_types -> key:mt.key -> l:list (map_elem mt) -> Pure (option mt.value) +val find_value_aux: + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + key:key_t -> l:list (map_elem key_t value_t) -> + Pure (option value_t) (requires True) (ensures fun res -> match res with | None -> True | Some value -> List.Tot.memP ({key; value;}) l ) -let rec find_value_aux #mt key l = +let rec find_value_aux #key_t #value_t #mt key l = match l with | [] -> None | h::t -> @@ -153,79 +173,78 @@ let rec find_value_aux #mt key l = [@@ "opaque_to_smt"] val find_value: - mt:map_types -> tag:string -> + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> prin:principal -> sess_id:nat -> - key:mt.key -> - crypto (option mt.value) -let find_value mt tag prin sess_id key = - let*? the_map = get_typed_state tag prin sess_id in + key:key_t -> + crypto (option value_t) +let find_value #key_t #value_t #mt prin sess_id key = + let*? the_map = get_state prin sess_id in return (find_value_aux key the_map.key_values) #push-options "--fuel 1" val initialize_map_invariant: {|invs:protocol_invariants|} -> - mt:map_types -> mpred:map_predicate mt -> tag:string -> + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + mpred:map_predicate key_t value_t -> prin:principal -> tr:trace -> Lemma (requires trace_invariant tr /\ - has_map_session_invariant invs (tag, mpred) + has_map_session_invariant invs mpred ) (ensures ( - let (_, tr_out) = initialize_map mt tag prin tr in + let (_, tr_out) = initialize_map key_t value_t prin tr in trace_invariant tr_out )) - [SMTPat (initialize_map mt tag prin tr); - SMTPat (has_map_session_invariant invs (tag, mpred)); + [SMTPat (initialize_map key_t value_t prin tr); + SMTPat (has_map_session_invariant invs mpred); SMTPat (trace_invariant tr) ] -let initialize_map_invariant #invs mt mpred tag prin tr = +let initialize_map_invariant #invs #key_t #value_t #mt mpred prin tr = reveal_opaque (`%initialize_map) (initialize_map) #pop-options #push-options "--fuel 1" val add_key_value_invariant: {|invs:protocol_invariants|} -> - mt:map_types -> mpred:map_predicate mt -> tag:string -> + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + mpred:map_predicate key_t value_t -> prin:principal -> sess_id:nat -> - key:mt.key -> value:mt.value -> + key:key_t -> value:value_t -> tr:trace -> Lemma (requires mpred.pred tr prin sess_id key value /\ trace_invariant tr /\ - has_map_session_invariant invs (tag, mpred) + has_map_session_invariant invs mpred ) (ensures ( - let (_, tr_out) = add_key_value mt tag prin sess_id key value tr in + let (_, tr_out) = add_key_value prin sess_id key value tr in trace_invariant tr_out )) - [SMTPat (add_key_value mt tag prin sess_id key value tr); - SMTPat (has_map_session_invariant invs (tag, mpred)); + [SMTPat (add_key_value prin sess_id key value tr); + SMTPat (has_map_session_invariant invs mpred); SMTPat (trace_invariant tr) ] -let add_key_value_invariant #invs mt mpred tag prin sess_id key value tr = - reveal_opaque (`%add_key_value) (add_key_value); - let (opt_the_map, tr) = get_typed_state #(map mt) tag prin sess_id tr in - match opt_the_map with - | None -> () - | Some the_map -> () +let add_key_value_invariant #invs #key_t #value_t #mt mpred prin sess_id key value tr = + reveal_opaque (`%add_key_value) (add_key_value #key_t #value_t) #pop-options val find_value_invariant: {|invs:protocol_invariants|} -> - mt:map_types -> mpred:map_predicate mt -> tag:string -> + #key_t:eqtype -> #value_t:Type0 -> {|map_types key_t value_t|} -> + mpred:map_predicate key_t value_t -> prin:principal -> sess_id:nat -> - key:mt.key -> + key:key_t -> tr:trace -> Lemma (requires trace_invariant tr /\ - has_map_session_invariant invs (tag, mpred) + has_map_session_invariant invs mpred ) (ensures ( - let (opt_value, tr_out) = find_value mt tag prin sess_id key tr in + let (opt_value, tr_out) = find_value prin sess_id key tr in tr_out == tr /\ ( match opt_value with | None -> True @@ -234,13 +253,13 @@ val find_value_invariant: ) ) )) - [SMTPat (find_value mt tag prin sess_id key tr); - SMTPat (has_map_session_invariant invs (tag, mpred)); + [SMTPat (find_value #key_t #value_t prin sess_id key tr); + SMTPat (has_map_session_invariant invs mpred); SMTPat (trace_invariant tr); ] -let find_value_invariant #invs mt mpred tag prin sess_id key tr = - reveal_opaque (`%find_value) (find_value); - let (opt_the_map, tr) = get_typed_state #(map mt) tag prin sess_id tr in +let find_value_invariant #invs #key_t #value_t #mt mpred prin sess_id key tr = + reveal_opaque (`%find_value) (find_value #key_t #value_t); + let (opt_the_map, tr) = get_state prin sess_id tr in match opt_the_map with | None -> () | Some the_map -> ( diff --git a/src/lib/state/DY.Lib.State.PKI.fst b/src/lib/state/DY.Lib.State.PKI.fst index 3f92e53..5a85c14 100644 --- a/src/lib/state/DY.Lib.State.PKI.fst +++ b/src/lib/state/DY.Lib.State.PKI.fst @@ -4,6 +4,8 @@ open Comparse open DY.Core open DY.Lib.Comparse.Glue open DY.Lib.Comparse.Parsers +open DY.Lib.State.Tagged +open DY.Lib.State.Typed open DY.Lib.State.Map #set-options "--fuel 1 --ifuel 1" @@ -38,19 +40,18 @@ type pki_key = { %splice [ps_pki_key] (gen_parser (`pki_key)) %splice [ps_pki_key_is_well_formed] (gen_is_well_formed_lemma (`pki_key)) -type pki_value (bytes:Type0) {|bytes_like bytes|} = { +[@@ with_bytes bytes] +type pki_value = { public_key: bytes; } %splice [ps_pki_value] (gen_parser (`pki_value)) %splice [ps_pki_value_is_well_formed] (gen_is_well_formed_lemma (`pki_value)) -val pki_types: map_types -let pki_types = { - key = pki_key; - ps_key = ps_pki_key; - value = pki_value bytes; - ps_value = ps_pki_value; +instance map_types_pki: map_types pki_key pki_value = { + tag = "DY.Lib.State.PKI"; + ps_key_t = ps_pki_key; + ps_value_t = ps_pki_value; } val is_public_key_for: @@ -65,37 +66,38 @@ let is_public_key_for #cinvs tr pk pk_type who = is_verification_key usg (principal_label who) tr pk ) -val pki_pred: {|crypto_invariants|} -> map_predicate pki_types +// The `#_` at the end is a workaround for FStarLang/FStar#3286 +val pki_pred: {|crypto_invariants|} -> map_predicate pki_key pki_value #_ let pki_pred #cinvs = { - pred = (fun tr prin sess_id (key:pki_types.key) value -> + pred = (fun tr prin sess_id key value -> is_public_key_for tr value.public_key key.ty key.who ); pred_later = (fun tr1 tr2 prin sess_id key value -> ()); pred_knowable = (fun tr prin sess_id key value -> ()); } -val pki_tag: string -let pki_tag = "DY.Lib.State.PKI" - val has_pki_invariant: protocol_invariants -> prop let has_pki_invariant invs = - has_map_session_invariant invs (pki_tag, pki_pred) + has_map_session_invariant invs pki_pred + +val pki_tag_and_invariant: {|crypto_invariants|} -> string & local_bytes_state_predicate +let pki_tag_and_invariant #ci = (map_types_pki.tag, local_state_predicate_to_local_bytes_state_predicate (map_session_invariant pki_pred)) (*** PKI API ***) [@@ "opaque_to_smt"] val initialize_pki: prin:principal -> crypto nat -let initialize_pki = initialize_map pki_types pki_tag +let initialize_pki = initialize_map pki_key pki_value #_ // another workaround for FStarLang/FStar#3286 [@@ "opaque_to_smt"] val install_public_key: principal -> nat -> public_key_type -> principal -> bytes -> crypto (option unit) let install_public_key prin sess_id pk_type who pk = - add_key_value pki_types pki_tag prin sess_id ({ty = pk_type; who;}) ({public_key = pk;}) + add_key_value prin sess_id ({ty = pk_type; who;}) ({public_key = pk;}) [@@ "opaque_to_smt"] val get_public_key: principal -> nat -> public_key_type -> principal -> crypto (option bytes) let get_public_key prin sess_id pk_type who = - let*? res = find_value pki_types pki_tag prin sess_id ({ty = pk_type; who;}) in + let*? res = find_value prin sess_id ({ty = pk_type; who;}) in return (Some res.public_key) val initialize_pki_invariant: diff --git a/src/lib/state/DY.Lib.State.PrivateKeys.fst b/src/lib/state/DY.Lib.State.PrivateKeys.fst index 766b123..f02b4bf 100644 --- a/src/lib/state/DY.Lib.State.PrivateKeys.fst +++ b/src/lib/state/DY.Lib.State.PrivateKeys.fst @@ -4,6 +4,8 @@ open Comparse open DY.Core open DY.Lib.Comparse.Glue open DY.Lib.Comparse.Parsers +open DY.Lib.State.Tagged +open DY.Lib.State.Typed open DY.Lib.State.Map #set-options "--fuel 1 --ifuel 1" @@ -22,19 +24,26 @@ type private_key_type = %splice [ps_private_key_type] (gen_parser (`private_key_type)) %splice [ps_private_key_type_is_well_formed] (gen_is_well_formed_lemma (`private_key_type)) -type private_key_value (bytes:Type0) {|bytes_like bytes|} = { +[@@ with_bytes bytes] +type private_key_key = { + ty:private_key_type; +} + +%splice [ps_private_key_key] (gen_parser (`private_key_key)) +%splice [ps_private_key_key_is_well_formed] (gen_is_well_formed_lemma (`private_key_key)) + +[@@ with_bytes bytes] +type private_key_value = { private_key: bytes; } %splice [ps_private_key_value] (gen_parser (`private_key_value)) %splice [ps_private_key_value_is_well_formed] (gen_is_well_formed_lemma (`private_key_value)) -val private_keys_types: map_types -let private_keys_types = { - key = private_key_type; - ps_key = ps_private_key_type; - value = private_key_value bytes; - ps_value = ps_private_key_value; +instance map_types_private_keys: map_types private_key_key private_key_value = { + tag = "DY.Lib.State.PrivateKeys"; + ps_key_t = ps_private_key_key; + ps_value_t = ps_private_key_value; } val is_private_key_for: @@ -49,21 +58,22 @@ let is_private_key_for #cinvs tr sk sk_type who = is_signature_key usg (principal_label who) tr sk ) -val private_keys_pred: {|crypto_invariants|} -> map_predicate private_keys_types +// The `#_` at the end is a workaround for FStarLang/FStar#3286 +val private_keys_pred: {|crypto_invariants|} -> map_predicate private_key_key private_key_value #_ let private_keys_pred #cinvs = { - pred = (fun tr prin sess_id (key:private_keys_types.key) value -> - is_private_key_for tr value.private_key key prin + pred = (fun tr prin sess_id key value -> + is_private_key_for tr value.private_key key.ty prin ); pred_later = (fun tr1 tr2 prin sess_id key value -> ()); pred_knowable = (fun tr prin sess_id key value -> ()); } -val private_keys_tag: string -let private_keys_tag = "DY.Lib.State.PrivateKeys" - val has_private_keys_invariant: protocol_invariants -> prop let has_private_keys_invariant invs = - has_map_session_invariant invs (private_keys_tag, private_keys_pred) + has_map_session_invariant invs private_keys_pred + +val private_keys_tag_and_invariant: {|crypto_invariants|} -> string & local_bytes_state_predicate +let private_keys_tag_and_invariant #ci = (map_types_private_keys.tag, local_state_predicate_to_local_bytes_state_predicate (map_session_invariant private_keys_pred)) val private_key_type_to_usage: private_key_type -> @@ -77,18 +87,18 @@ let private_key_type_to_usage sk_type = [@@ "opaque_to_smt"] val initialize_private_keys: prin:principal -> crypto nat -let initialize_private_keys = initialize_map private_keys_types private_keys_tag +let initialize_private_keys = initialize_map private_key_key private_key_value #_ // another workaround for FStarLang/FStar#3286 [@@ "opaque_to_smt"] val generate_private_key: principal -> nat -> private_key_type -> crypto (option unit) let generate_private_key prin sess_id sk_type = let* sk = mk_rand (private_key_type_to_usage sk_type) (principal_label prin) 64 in //TODO - add_key_value private_keys_types private_keys_tag prin sess_id sk_type ({private_key = sk;}) + add_key_value prin sess_id ({ty = sk_type}) ({private_key = sk;}) [@@ "opaque_to_smt"] val get_private_key: principal -> nat -> private_key_type -> crypto (option bytes) let get_private_key prin sess_id sk_type = - let*? res = find_value private_keys_types private_keys_tag prin sess_id sk_type in + let*? res = find_value prin sess_id ({ty = sk_type}) in return (Some res.private_key) val initialize_private_keys_invariant: diff --git a/src/lib/state/DY.Lib.State.Tagged.fst b/src/lib/state/DY.Lib.State.Tagged.fst index dbfd9c1..4d7f843 100644 --- a/src/lib/state/DY.Lib.State.Tagged.fst +++ b/src/lib/state/DY.Lib.State.Tagged.fst @@ -8,22 +8,22 @@ open DY.Lib.Comparse.Parsers #set-options "--fuel 1 --ifuel 1" -(*** Session predicates ***) +(*** Tagged state predicates ***) [@@ with_bytes bytes] -type session = { +type tagged_state = { [@@@ with_parser #bytes ps_string] tag: string; content: bytes; } -%splice [ps_session] (gen_parser (`session)) -%splice [ps_session_is_well_formed] (gen_is_well_formed_lemma (`session)) +%splice [ps_tagged_state] (gen_parser (`tagged_state)) +%splice [ps_tagged_state_is_well_formed] (gen_is_well_formed_lemma (`tagged_state)) -instance parseable_serializeable_session: parseable_serializeable bytes session = mk_parseable_serializeable (ps_session) +instance parseable_serializeable_tagged_state: parseable_serializeable bytes tagged_state = mk_parseable_serializeable (ps_tagged_state) noeq -type session_pred {|crypto_invariants|} = { +type local_bytes_state_predicate {|crypto_invariants|} = { pred: trace -> principal -> nat -> bytes -> prop; pred_later: tr1:trace -> tr2:trace -> @@ -40,14 +40,14 @@ type session_pred {|crypto_invariants|} = { ; } -let split_session_pred_func {|crypto_invariants|} : split_predicate_input_values = { +let split_local_bytes_state_predicate_func {|crypto_invariants|} : split_predicate_input_values = { tagged_data_t = trace & principal & nat & bytes; tag_t = string; encoded_tag_t = string; raw_data_t = trace & principal & nat & bytes; decode_tagged_data = (fun (tr, prin, sess_id, sess_content) -> ( - match parse session sess_content with + match parse tagged_state sess_content with | Some ({tag; content}) -> Some (tag, (tr, prin, sess_id, content)) | None -> None )); @@ -55,7 +55,7 @@ let split_session_pred_func {|crypto_invariants|} : split_predicate_input_values encode_tag = (fun s -> s); encode_tag_inj = (fun l1 l2 -> ()); - local_pred = session_pred; + local_pred = local_bytes_state_predicate; global_pred = trace -> principal -> nat -> bytes -> prop; apply_local_pred = (fun spred (tr, prin, sess_id, content) -> @@ -70,71 +70,71 @@ let split_session_pred_func {|crypto_invariants|} : split_predicate_input_values apply_mk_global_pred = (fun spred x -> ()); } -val has_session_pred: protocol_invariants -> (string & session_pred) -> prop -let has_session_pred invs (tag, spred) = - has_local_pred split_session_pred_func (state_pred) (tag, spred) +val has_local_bytes_state_predicate: protocol_invariants -> (string & local_bytes_state_predicate) -> prop +let has_local_bytes_state_predicate invs (tag, spred) = + has_local_pred split_local_bytes_state_predicate_func (state_pred) (tag, spred) -(*** Global session predicate builder ***) +(*** Global tagged state predicate builder ***) -val mk_global_session_pred: {|crypto_invariants|} -> list (string & session_pred) -> trace -> principal -> nat -> bytes -> prop -let mk_global_session_pred #cinvs l = - mk_global_pred split_session_pred_func l +val mk_global_local_bytes_state_predicate: {|crypto_invariants|} -> list (string & local_bytes_state_predicate) -> trace -> principal -> nat -> bytes -> prop +let mk_global_local_bytes_state_predicate #cinvs l = + mk_global_pred split_local_bytes_state_predicate_func l -val mk_global_session_pred_correct: invs:protocol_invariants -> lpreds:list (string & session_pred) -> Lemma +val mk_global_local_bytes_state_predicate_correct: invs:protocol_invariants -> lpreds:list (string & local_bytes_state_predicate) -> Lemma (requires - invs.trace_invs.state_pred.pred == mk_global_session_pred lpreds /\ + invs.trace_invs.state_pred.pred == mk_global_local_bytes_state_predicate lpreds /\ List.Tot.no_repeats_p (List.Tot.map fst lpreds) ) - (ensures for_allP (has_session_pred invs) lpreds) -let mk_global_session_pred_correct invs lpreds = - for_allP_eq (has_session_pred invs) lpreds; - FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_pred_correct split_session_pred_func lpreds)) + (ensures for_allP (has_local_bytes_state_predicate invs) lpreds) +let mk_global_local_bytes_state_predicate_correct invs lpreds = + for_allP_eq (has_local_bytes_state_predicate invs) lpreds; + FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_pred_correct split_local_bytes_state_predicate_func lpreds)) -val mk_global_session_pred_later: - cinvs:crypto_invariants -> lpreds:list (string & session_pred) -> +val mk_global_local_bytes_state_predicate_later: + cinvs:crypto_invariants -> lpreds:list (string & local_bytes_state_predicate) -> tr1:trace -> tr2:trace -> prin:principal -> sess_id:nat -> full_content:bytes -> Lemma - (requires mk_global_session_pred lpreds tr1 prin sess_id full_content /\ tr1 <$ tr2) - (ensures mk_global_session_pred lpreds tr2 prin sess_id full_content) -let mk_global_session_pred_later cinvs lpreds tr1 tr2 prin sess_id full_content = - mk_global_pred_eq split_session_pred_func lpreds (tr1, prin, sess_id, full_content); + (requires mk_global_local_bytes_state_predicate lpreds tr1 prin sess_id full_content /\ tr1 <$ tr2) + (ensures mk_global_local_bytes_state_predicate lpreds tr2 prin sess_id full_content) +let mk_global_local_bytes_state_predicate_later cinvs lpreds tr1 tr2 prin sess_id full_content = + mk_global_pred_eq split_local_bytes_state_predicate_func lpreds (tr1, prin, sess_id, full_content); eliminate exists tag lpred raw_data. List.Tot.memP (tag, lpred) lpreds /\ - split_session_pred_func.apply_local_pred lpred raw_data /\ - split_session_pred_func.decode_tagged_data (tr1, prin, sess_id, full_content) == Some (split_session_pred_func.encode_tag tag, raw_data) - returns mk_global_session_pred lpreds tr2 prin sess_id full_content + split_local_bytes_state_predicate_func.apply_local_pred lpred raw_data /\ + split_local_bytes_state_predicate_func.decode_tagged_data (tr1, prin, sess_id, full_content) == Some (split_local_bytes_state_predicate_func.encode_tag tag, raw_data) + returns mk_global_local_bytes_state_predicate lpreds tr2 prin sess_id full_content with _. ( - let Some (_, (_, _, _, content)) = split_session_pred_func.decode_tagged_data (tr1, prin, sess_id, full_content) in + let Some (_, (_, _, _, content)) = split_local_bytes_state_predicate_func.decode_tagged_data (tr1, prin, sess_id, full_content) in lpred.pred_later tr1 tr2 prin sess_id content; - mk_global_pred_eq split_session_pred_func lpreds (tr2, prin, sess_id, full_content); - assert(split_session_pred_func.apply_local_pred lpred (tr2, prin, sess_id, content)) + mk_global_pred_eq split_local_bytes_state_predicate_func lpreds (tr2, prin, sess_id, full_content); + assert(split_local_bytes_state_predicate_func.apply_local_pred lpred (tr2, prin, sess_id, content)) ) -val mk_global_session_pred_knowable: - cinvs:crypto_invariants -> lpreds:list (string & session_pred) -> +val mk_global_local_bytes_state_predicate_knowable: + cinvs:crypto_invariants -> lpreds:list (string & local_bytes_state_predicate) -> tr:trace -> prin:principal -> sess_id:nat -> full_content:bytes -> Lemma - (requires mk_global_session_pred lpreds tr prin sess_id full_content) + (requires mk_global_local_bytes_state_predicate lpreds tr prin sess_id full_content) (ensures is_knowable_by (principal_state_label prin sess_id) tr full_content) -let mk_global_session_pred_knowable cinvs lpreds tr prin sess_id full_content = - mk_global_pred_eq split_session_pred_func lpreds (tr, prin, sess_id, full_content); +let mk_global_local_bytes_state_predicate_knowable cinvs lpreds tr prin sess_id full_content = + mk_global_pred_eq split_local_bytes_state_predicate_func lpreds (tr, prin, sess_id, full_content); eliminate exists tag lpred raw_data. List.Tot.memP (tag, lpred) lpreds /\ - split_session_pred_func.apply_local_pred lpred raw_data /\ - split_session_pred_func.decode_tagged_data (tr, prin, sess_id, full_content) == Some (split_session_pred_func.encode_tag tag, raw_data) + split_local_bytes_state_predicate_func.apply_local_pred lpred raw_data /\ + split_local_bytes_state_predicate_func.decode_tagged_data (tr, prin, sess_id, full_content) == Some (split_local_bytes_state_predicate_func.encode_tag tag, raw_data) returns is_knowable_by (principal_state_label prin sess_id) tr full_content with _. ( - let Some (tag, (_, _, _, content)) = split_session_pred_func.decode_tagged_data (tr, prin, sess_id, full_content) in + let Some (tag, (_, _, _, content)) = split_local_bytes_state_predicate_func.decode_tagged_data (tr, prin, sess_id, full_content) in lpred.pred_knowable tr prin sess_id content; - serialize_parse_inv_lemma session full_content; - serialize_wf_lemma session (is_knowable_by (principal_state_label prin sess_id) tr) ({tag; content}) + serialize_parse_inv_lemma tagged_state full_content; + serialize_wf_lemma tagged_state (is_knowable_by (principal_state_label prin sess_id) tr) ({tag; content}) ) -val mk_state_predicate: cinvs:crypto_invariants -> list (string & session_pred) -> state_predicate cinvs +val mk_state_predicate: cinvs:crypto_invariants -> list (string & local_bytes_state_predicate) -> state_predicate cinvs let mk_state_predicate cinvs lpreds = { - pred = mk_global_session_pred lpreds; - pred_later = mk_global_session_pred_later cinvs lpreds; - pred_knowable = mk_global_session_pred_knowable cinvs lpreds; + pred = mk_global_local_bytes_state_predicate lpreds; + pred_later = mk_global_local_bytes_state_predicate_later cinvs lpreds; + pred_knowable = mk_global_local_bytes_state_predicate_knowable cinvs lpreds; } (*** Predicates on trace ***) @@ -143,7 +143,7 @@ let mk_state_predicate cinvs lpreds = val tagged_state_was_set: trace -> string -> principal -> nat -> bytes -> prop let tagged_state_was_set tr tag prin sess_id content = let full_content = {tag; content;} in - let full_content_bytes = serialize session full_content in + let full_content_bytes = serialize tagged_state full_content in state_was_set tr prin sess_id full_content_bytes (*** API for tagged sessions ***) @@ -152,14 +152,14 @@ let tagged_state_was_set tr tag prin sess_id content = val set_tagged_state: string -> principal -> nat -> bytes -> crypto unit let set_tagged_state tag prin sess_id content = let full_content = {tag; content;} in - let full_content_bytes = serialize session full_content in + let full_content_bytes = serialize tagged_state full_content in set_state prin sess_id full_content_bytes [@@ "opaque_to_smt"] val get_tagged_state: string -> principal -> nat -> crypto (option bytes) let get_tagged_state the_tag prin sess_id = let*? full_content_bytes = get_state prin sess_id in - match parse session full_content_bytes with + match parse tagged_state full_content_bytes with | None -> return None | Some ({tag; content;}) -> if tag = the_tag then return (Some content) @@ -167,13 +167,13 @@ let get_tagged_state the_tag prin sess_id = val set_tagged_state_invariant: invs:protocol_invariants -> - tag:string -> spred:session_pred -> + tag:string -> spred:local_bytes_state_predicate -> prin:principal -> sess_id:nat -> content:bytes -> tr:trace -> Lemma (requires spred.pred tr prin sess_id content /\ trace_invariant tr /\ - has_session_pred invs (tag, spred) + has_local_bytes_state_predicate invs (tag, spred) ) (ensures ( let ((), tr_out) = set_tagged_state tag prin sess_id content tr in @@ -182,22 +182,22 @@ val set_tagged_state_invariant: )) [SMTPat (set_tagged_state tag prin sess_id content tr); SMTPat (trace_invariant tr); - SMTPat (has_session_pred invs (tag, spred))] + SMTPat (has_local_bytes_state_predicate invs (tag, spred))] let set_tagged_state_invariant invs tag spred prin sess_id content tr = reveal_opaque (`%set_tagged_state) (set_tagged_state); reveal_opaque (`%tagged_state_was_set) (tagged_state_was_set); let full_content = {tag; content;} in - parse_serialize_inv_lemma #bytes session full_content; - local_eq_global_lemma split_session_pred_func state_pred tag spred (tr, prin, sess_id, serialize _ full_content) (tr, prin, sess_id, content) + parse_serialize_inv_lemma #bytes tagged_state full_content; + local_eq_global_lemma split_local_bytes_state_predicate_func state_pred tag spred (tr, prin, sess_id, serialize _ full_content) (tr, prin, sess_id, content) val get_tagged_state_invariant: invs:protocol_invariants -> - tag:string -> spred:session_pred -> + tag:string -> spred:local_bytes_state_predicate -> prin:principal -> sess_id:nat -> tr:trace -> Lemma (requires trace_invariant tr /\ - has_session_pred invs (tag, spred) + has_local_bytes_state_predicate invs (tag, spred) ) (ensures ( let (opt_content, tr_out) = get_tagged_state tag prin sess_id tr in @@ -211,7 +211,7 @@ val get_tagged_state_invariant: )) [SMTPat (get_tagged_state tag prin sess_id tr); SMTPat (trace_invariant tr); - SMTPat (has_session_pred invs (tag, spred))] + SMTPat (has_local_bytes_state_predicate invs (tag, spred))] let get_tagged_state_invariant invs tag spred prin sess_id tr = reveal_opaque (`%get_tagged_state) (get_tagged_state); let (opt_content, tr_out) = get_tagged_state tag prin sess_id tr in @@ -219,28 +219,28 @@ let get_tagged_state_invariant invs tag spred prin sess_id tr = | None -> () | Some content -> let (Some full_content_bytes, tr) = get_state prin sess_id tr in - local_eq_global_lemma split_session_pred_func state_pred tag spred (tr, prin, sess_id, full_content_bytes) (tr, prin, sess_id, content) + local_eq_global_lemma split_local_bytes_state_predicate_func state_pred tag spred (tr, prin, sess_id, full_content_bytes) (tr, prin, sess_id, content) (*** Theorem ***) val tagged_state_was_set_implies_pred: invs:protocol_invariants -> tr:trace -> - tag:string -> spred:session_pred -> + tag:string -> spred:local_bytes_state_predicate -> prin:principal -> sess_id:nat -> content:bytes -> Lemma (requires tagged_state_was_set tr tag prin sess_id content /\ trace_invariant tr /\ - has_session_pred invs (tag, spred) + has_local_bytes_state_predicate invs (tag, spred) ) (ensures spred.pred tr prin sess_id content) [SMTPat (tagged_state_was_set tr tag prin sess_id content); SMTPat (trace_invariant tr); - SMTPat (has_session_pred invs (tag, spred)); + SMTPat (has_local_bytes_state_predicate invs (tag, spred)); ] let tagged_state_was_set_implies_pred invs tr tag spred prin sess_id content = reveal_opaque (`%tagged_state_was_set) (tagged_state_was_set); let full_content = {tag; content;} in - parse_serialize_inv_lemma #bytes session full_content; - let full_content_bytes: bytes = serialize session full_content in - local_eq_global_lemma split_session_pred_func state_pred tag spred (tr, prin, sess_id, full_content_bytes) (tr, prin, sess_id, content) + parse_serialize_inv_lemma #bytes tagged_state full_content; + let full_content_bytes: bytes = serialize tagged_state full_content in + local_eq_global_lemma split_local_bytes_state_predicate_func state_pred tag spred (tr, prin, sess_id, full_content_bytes) (tr, prin, sess_id, content) diff --git a/src/lib/state/DY.Lib.State.Typed.fst b/src/lib/state/DY.Lib.State.Typed.fst index e115d0c..929ea78 100644 --- a/src/lib/state/DY.Lib.State.Typed.fst +++ b/src/lib/state/DY.Lib.State.Typed.fst @@ -5,8 +5,22 @@ open DY.Core open DY.Lib.Comparse.Glue open DY.Lib.State.Tagged +class local_state (a:Type0) = { + tag: string; + [@@@FStar.Tactics.Typeclasses.tcinstance] + format: parseable_serializeable bytes a; +} + +val mk_local_state_instance: + #a:Type0 -> {|parseable_serializeable bytes a|} -> string -> + local_state a +let mk_local_state_instance #a #format tag = { + tag; + format; +} + noeq -type typed_session_pred {|crypto_invariants|} (a:Type) {|parseable_serializeable bytes a|} = { +type local_state_predicate {|crypto_invariants|} (a:Type) {|parseable_serializeable bytes a|} = { pred: trace -> principal -> nat -> a -> prop; pred_later: tr1:trace -> tr2:trace -> @@ -23,11 +37,11 @@ type typed_session_pred {|crypto_invariants|} (a:Type) {|parseable_serializeable ; } -val typed_session_pred_to_session_pred: +val local_state_predicate_to_local_bytes_state_predicate: {|crypto_invariants|} -> #a:Type -> {|parseable_serializeable bytes a|} -> - typed_session_pred a -> session_pred -let typed_session_pred_to_session_pred #cinvs #a #ps_a tspred = + local_state_predicate a -> local_bytes_state_predicate +let local_state_predicate_to_local_bytes_state_predicate #cinvs #a #ps_a tspred = { pred = (fun tr prin sess_id content_bytes -> match parse a content_bytes with @@ -46,74 +60,74 @@ let typed_session_pred_to_session_pred #cinvs #a #ps_a tspred = ); } -val has_typed_session_pred: - #a:Type -> {|parseable_serializeable bytes a|} -> - invs:protocol_invariants -> (string & typed_session_pred a) -> +val has_local_state_predicate: + #a:Type -> {|local_state a|} -> + invs:protocol_invariants -> local_state_predicate a -> prop -let has_typed_session_pred #a #ps_a invs (tag, spred) = - has_session_pred invs (tag, (typed_session_pred_to_session_pred spred)) +let has_local_state_predicate #a #ls invs spred = + has_local_bytes_state_predicate invs (ls.tag, (local_state_predicate_to_local_bytes_state_predicate spred)) [@@ "opaque_to_smt"] -val typed_state_was_set: - #a:Type -> {|parseable_serializeable bytes a|} -> - trace -> string -> principal -> nat -> a -> +val state_was_set: + #a:Type -> {|local_state a|} -> + trace -> principal -> nat -> a -> prop -let typed_state_was_set #a #ps_a tr tag prin sess_id content = - tagged_state_was_set tr tag prin sess_id (serialize _ content) +let state_was_set #a #ls tr prin sess_id content = + tagged_state_was_set tr ls.tag prin sess_id (serialize _ content) [@@ "opaque_to_smt"] -val set_typed_state: - #a:Type -> {|parseable_serializeable bytes a|} -> - string -> principal -> nat -> a -> crypto unit -let set_typed_state tag prin sess_id content = - set_tagged_state tag prin sess_id (serialize _ content) +val set_state: + #a:Type -> {|local_state a|} -> + principal -> nat -> a -> crypto unit +let set_state #a #ls prin sess_id content = + set_tagged_state ls.tag prin sess_id (serialize _ content) [@@ "opaque_to_smt"] -val get_typed_state: - #a:Type -> {|parseable_serializeable bytes a|} -> - string -> principal -> nat -> crypto (option a) -let get_typed_state #a tag prin sess_id = - let*? content_bytes = get_tagged_state tag prin sess_id in +val get_state: + #a:Type -> {|local_state a|} -> + principal -> nat -> crypto (option a) +let get_state #a #ls prin sess_id = + let*? content_bytes = get_tagged_state ls.tag prin sess_id in match parse a content_bytes with | None -> return None | Some content -> return (Some content) -val set_typed_state_invariant: - #a:Type -> {|parseable_serializeable bytes a|} -> +val set_state_invariant: + #a:Type -> {|local_state a|} -> {|invs:protocol_invariants|} -> - tag:string -> spred:typed_session_pred a -> + spred:local_state_predicate a -> prin:principal -> sess_id:nat -> content:a -> tr:trace -> Lemma (requires spred.pred tr prin sess_id content /\ trace_invariant tr /\ - has_typed_session_pred invs (tag, spred) + has_local_state_predicate invs spred ) (ensures ( - let ((), tr_out) = set_typed_state tag prin sess_id content tr in + let ((), tr_out) = set_state prin sess_id content tr in trace_invariant tr_out /\ - typed_state_was_set tr_out tag prin sess_id content + state_was_set tr_out prin sess_id content )) - [SMTPat (set_typed_state tag prin sess_id content tr); + [SMTPat (set_state prin sess_id content tr); SMTPat (trace_invariant tr); - SMTPat (has_typed_session_pred invs (tag, spred))] -let set_typed_state_invariant #a #ps_a #invs tag spred prin sess_id content tr = - reveal_opaque (`%set_typed_state) (set_typed_state #a); - reveal_opaque (`%typed_state_was_set) (typed_state_was_set #a); + SMTPat (has_local_state_predicate invs spred)] +let set_state_invariant #a #ls #invs spred prin sess_id content tr = + reveal_opaque (`%set_state) (set_state #a); + reveal_opaque (`%state_was_set) (state_was_set #a); parse_serialize_inv_lemma #bytes a content -val get_typed_state_invariant: - #a:Type -> {|parseable_serializeable bytes a|} -> +val get_state_invariant: + #a:Type -> {|local_state a|} -> {|invs:protocol_invariants|} -> - tag:string -> spred:typed_session_pred a -> + spred:local_state_predicate a -> prin:principal -> sess_id:nat -> tr:trace -> Lemma (requires trace_invariant tr /\ - has_typed_session_pred invs (tag, spred) + has_local_state_predicate invs spred ) (ensures ( - let (opt_content, tr_out) = get_typed_state tag prin sess_id tr in + let (opt_content, tr_out) = get_state prin sess_id tr in tr == tr_out /\ ( match opt_content with | None -> True @@ -122,28 +136,28 @@ val get_typed_state_invariant: ) ) )) - [SMTPat (get_typed_state #a tag prin sess_id tr); + [SMTPat (get_state #a prin sess_id tr); SMTPat (trace_invariant tr); - SMTPat (has_typed_session_pred invs (tag, spred))] -let get_typed_state_invariant #a #ps_a #invs tag spred prin sess_id tr = - reveal_opaque (`%get_typed_state) (get_typed_state #a) + SMTPat (has_local_state_predicate invs spred)] +let get_state_invariant #a #ls #invs spred prin sess_id tr = + reveal_opaque (`%get_state) (get_state #a) -val typed_state_was_set_implies_pred: - #a:Type -> {|parseable_serializeable bytes a|} -> +val state_was_set_implies_pred: + #a:Type -> {|local_state a|} -> invs:protocol_invariants -> tr:trace -> - tag:string -> spred:typed_session_pred a -> + spred:local_state_predicate a -> prin:principal -> sess_id:nat -> content:a -> Lemma (requires - typed_state_was_set tr tag prin sess_id content /\ + state_was_set tr prin sess_id content /\ trace_invariant tr /\ - has_typed_session_pred invs (tag, spred) + has_local_state_predicate invs spred ) (ensures spred.pred tr prin sess_id content) - [SMTPat (typed_state_was_set tr tag prin sess_id content); + [SMTPat (state_was_set tr prin sess_id content); SMTPat (trace_invariant tr); - SMTPat (has_typed_session_pred invs (tag, spred)); + SMTPat (has_local_state_predicate invs spred); ] -let typed_state_was_set_implies_pred #a #ps_a invs tr tag spred prin sess_id content = +let state_was_set_implies_pred #a #ls invs tr spred prin sess_id content = parse_serialize_inv_lemma #bytes a content; - reveal_opaque (`%typed_state_was_set) (typed_state_was_set #a) + reveal_opaque (`%state_was_set) (state_was_set #a) diff --git a/src/lib/utils/DY.Lib.Printing.fst b/src/lib/utils/DY.Lib.Printing.fst index 07a281c..948bf02 100644 --- a/src/lib/utils/DY.Lib.Printing.fst +++ b/src/lib/utils/DY.Lib.Printing.fst @@ -104,13 +104,14 @@ let private_key_type_to_string t = | DY.Lib.State.PrivateKeys.PkDec u -> "PkDec " ^ u | DY.Lib.State.PrivateKeys.Sign u -> "Sign " ^ u -val private_keys_types_to_string: (list (map_elem DY.Lib.State.PrivateKeys.private_keys_types)) -> string +// The `#_` at the end is a workaround for FStarLang/FStar#3286 +val private_keys_types_to_string: (list (map_elem DY.Lib.State.PrivateKeys.private_key_key DY.Lib.State.PrivateKeys.private_key_value #_)) -> string let rec private_keys_types_to_string m = match m with | [] -> "" | hd :: tl -> ( (private_keys_types_to_string tl) ^ - Printf.sprintf "%s = (%s)," (private_key_type_to_string hd.key) (bytes_to_string hd.value.private_key) + Printf.sprintf "%s = (%s)," (private_key_type_to_string hd.key.ty) (bytes_to_string hd.value.private_key) ) val public_key_type_to_string: DY.Lib.State.PKI.public_key_type -> string @@ -119,7 +120,8 @@ let public_key_type_to_string t = | DY.Lib.State.PKI.PkEnc u -> "PkEnc " ^ u | DY.Lib.State.PKI.Verify u -> "Verify " ^ u -val pki_types_to_string: (list (map_elem DY.Lib.State.PKI.pki_types)) -> string +// The `#_` at the end is a workaround for FStarLang/FStar#3286 +val pki_types_to_string: (list (map_elem DY.Lib.State.PKI.pki_key DY.Lib.State.PKI.pki_value #_)) -> string let rec pki_types_to_string m = match m with | [] -> "" @@ -130,12 +132,14 @@ let rec pki_types_to_string m = val default_private_keys_state_to_string: bytes -> option string let default_private_keys_state_to_string content_bytes = - let? state = parse (map DY.Lib.State.PrivateKeys.private_keys_types) content_bytes in + // another workaround for FStarLang/FStar#3286 + let? state = parse (map DY.Lib.State.PrivateKeys.private_key_key DY.Lib.State.PrivateKeys.private_key_value #_) content_bytes in Some (Printf.sprintf "[%s]" (private_keys_types_to_string state.key_values)) val default_pki_state_to_string: bytes -> option string let default_pki_state_to_string content_bytes = - let? state = parse (map DY.Lib.State.PKI.pki_types) content_bytes in + // another workaround for FStarLang/FStar#3286 + let? state = parse (map DY.Lib.State.PKI.pki_key DY.Lib.State.PKI.pki_value #_) content_bytes in Some (Printf.sprintf "[%s]" (pki_types_to_string state.key_values)) /// Searches for a printer with the correct tag @@ -159,7 +163,7 @@ let option_to_string parse_fn elem = val state_to_string: list (string & (bytes -> option string)) -> bytes -> string let state_to_string printer_list full_content_bytes = - let full_content = parse session full_content_bytes in + let full_content = parse tagged_state full_content_bytes in match full_content with | Some ({tag; content}) -> ( let parser = find_printer printer_list tag in @@ -260,8 +264,8 @@ let trace_to_string_printers_builder message_to_string state_to_string event_to_ state_to_string = ( List.append state_to_string ( [ - (DY.Lib.State.PrivateKeys.private_keys_tag, default_private_keys_state_to_string); - (DY.Lib.State.PKI.pki_tag, default_pki_state_to_string) + (DY.Lib.State.PrivateKeys.map_types_private_keys.tag, default_private_keys_state_to_string); + (DY.Lib.State.PKI.map_types_pki.tag, default_pki_state_to_string) ] ) // User supplied functions will override the default functions because the // find printer function will choose the first match.