Skip to content

Commit

Permalink
feat: add typeclasses for easier use of State.Typed (#19)
Browse files Browse the repository at this point in the history
* feat: add typeclasses for easier use of State.Typed

* cleanup: improve type and function names in DY.Lib.State.Tagged

* cleanup: improve `map_types` instance names

* cleanup: use a dedicated type for PrivateKey's map keys
  • Loading branch information
TWal authored May 7, 2024
1 parent 6fe7c9d commit 452210f
Show file tree
Hide file tree
Showing 10 changed files with 321 additions and 270 deletions.
2 changes: 1 addition & 1 deletion examples/nsl_pk/DY.Example.NSL.Debug.Printing.fst
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,5 @@ val get_nsl_trace_to_string_printers: bytes -> bytes -> trace_to_string_printers
let get_nsl_trace_to_string_printers priv_key_alice priv_key_bob =
trace_to_string_printers_builder
(message_to_string priv_key_alice priv_key_bob)
[(nsl_session_tag, session_to_string)]
[(local_state_nsl_session.tag, session_to_string)]
[(event_nsl_event.tag, event_to_string)]
26 changes: 13 additions & 13 deletions examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.Proof.fst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ open DY.Example.NSL.Protocol.Stateful

/// The (local) state predicate.

let nsl_session_pred: typed_session_pred nsl_session = {
let nsl_session_pred: local_state_predicate nsl_session = {
pred = (fun tr prin sess_id st ->
match st with
| InitiatorSentMsg1 bob n_a -> (
Expand Down Expand Up @@ -82,9 +82,9 @@ let nsl_event_pred: event_predicate nsl_event =
/// List of all local state predicates.

let all_sessions = [
(pki_tag, typed_session_pred_to_session_pred (map_session_invariant pki_pred));
(private_keys_tag, typed_session_pred_to_session_pred (map_session_invariant private_keys_pred));
(nsl_session_tag, typed_session_pred_to_session_pred nsl_session_pred);
pki_tag_and_invariant;
private_keys_tag_and_invariant;
(local_state_nsl_session.tag, local_state_predicate_to_local_bytes_state_predicate nsl_session_pred);
]

/// List of all local event predicates.
Expand All @@ -107,19 +107,19 @@ instance nsl_protocol_invs: protocol_invariants = {

/// Lemmas that the global state predicate contains all the local ones

val all_sessions_has_all_sessions: unit -> Lemma (norm [delta_only [`%all_sessions; `%for_allP]; iota; zeta] (for_allP (has_session_pred nsl_protocol_invs) all_sessions))
val all_sessions_has_all_sessions: unit -> Lemma (norm [delta_only [`%all_sessions; `%for_allP]; iota; zeta] (for_allP (has_local_bytes_state_predicate nsl_protocol_invs) all_sessions))
let all_sessions_has_all_sessions () =
assert_norm(List.Tot.no_repeats_p (List.Tot.map fst (all_sessions)));
mk_global_session_pred_correct nsl_protocol_invs all_sessions;
norm_spec [delta_only [`%all_sessions; `%for_allP]; iota; zeta] (for_allP (has_session_pred nsl_protocol_invs) all_sessions)
mk_global_local_bytes_state_predicate_correct nsl_protocol_invs all_sessions;
norm_spec [delta_only [`%all_sessions; `%for_allP]; iota; zeta] (for_allP (has_local_bytes_state_predicate nsl_protocol_invs) all_sessions)

val full_nsl_session_pred_has_pki_invariant: squash (has_pki_invariant nsl_protocol_invs)
let full_nsl_session_pred_has_pki_invariant = all_sessions_has_all_sessions ()

val full_nsl_session_pred_has_private_keys_invariant: squash (has_private_keys_invariant nsl_protocol_invs)
let full_nsl_session_pred_has_private_keys_invariant = all_sessions_has_all_sessions ()

val full_nsl_session_pred_has_nsl_invariant: squash (has_typed_session_pred nsl_protocol_invs (nsl_session_tag, nsl_session_pred))
val full_nsl_session_pred_has_nsl_invariant: squash (has_local_state_predicate nsl_protocol_invs nsl_session_pred)
let full_nsl_session_pred_has_nsl_invariant = all_sessions_has_all_sessions ()

/// Lemmas that the global event predicate contains all the local ones
Expand Down Expand Up @@ -159,7 +159,7 @@ val send_msg1_proof:
trace_invariant tr_out
))
let send_msg1_proof tr global_sess_id alice sess_id =
match get_typed_state #nsl_session nsl_session_tag alice sess_id tr with
match get_state alice sess_id tr with
| (Some (InitiatorSentMsg1 bob n_a), tr) -> (
match get_public_key alice global_sess_id.pki (PkEnc "NSL.PublicKey") bob tr with
| (None, tr) -> ()
Expand Down Expand Up @@ -200,7 +200,7 @@ val send_msg2_proof:
trace_invariant tr_out
))
let send_msg2_proof tr global_sess_id bob sess_id =
match get_typed_state nsl_session_tag bob sess_id tr with
match get_state bob sess_id tr with
| (Some (ResponderSentMsg2 alice n_a n_b), tr) -> (
match get_public_key bob global_sess_id.pki (PkEnc "NSL.PublicKey") alice tr with
| (None, tr) -> ()
Expand All @@ -227,7 +227,7 @@ let prepare_msg3_proof tr global_sess_id alice sess_id msg_id =
match get_private_key alice global_sess_id.private_keys (PkDec "NSL.PublicKey") tr with
| (None, tr) -> ()
| (Some sk_a, tr) -> (
match get_typed_state nsl_session_tag alice sess_id tr with
match get_state alice sess_id tr with
| (Some (InitiatorSentMsg1 bob n_a), tr) -> (
decode_message2_proof tr alice bob msg sk_a n_a
)
Expand All @@ -245,7 +245,7 @@ val send_msg3_proof:
trace_invariant tr_out
))
let send_msg3_proof tr global_sess_id alice sess_id =
match get_typed_state nsl_session_tag alice sess_id tr with
match get_state alice sess_id tr with
| (Some (InitiatorSentMsg3 bob n_a n_b), tr) -> (
match get_public_key alice global_sess_id.pki (PkEnc "NSL.PublicKey") bob tr with
| (None, tr) -> ()
Expand Down Expand Up @@ -289,7 +289,7 @@ let prepare_msg4 tr global_sess_id bob sess_id msg_id =
match get_private_key bob global_sess_id.private_keys (PkDec "NSL.PublicKey") tr with
| (None, tr) -> ()
| (Some sk_b, tr) -> (
match get_typed_state nsl_session_tag bob sess_id tr with
match get_state bob sess_id tr with
| (Some (ResponderSentMsg2 alice n_a n_b), tr) -> (
decode_message3_proof tr alice bob msg sk_b n_b;

Expand Down
24 changes: 13 additions & 11 deletions examples/nsl_pk/DY.Example.NSL.Protocol.Stateful.fst
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ type nsl_session =
instance nsl_session_parseable_serializeable: parseable_serializeable bytes nsl_session
= mk_parseable_serializeable ps_nsl_session

val nsl_session_tag: string
let nsl_session_tag = "NSL.Session"
instance local_state_nsl_session: local_state nsl_session = {
tag = "NSL.Session";
format = nsl_session_parseable_serializeable;
}

(*** Event type ***)

Expand Down Expand Up @@ -61,12 +63,12 @@ let prepare_msg1 alice bob =
let* n_a = mk_rand NoUsage (join (principal_label alice) (principal_label bob)) 32 in
trigger_event alice (Initiate1 alice bob n_a);*
let* sess_id = new_session_id alice in
set_typed_state nsl_session_tag alice sess_id (InitiatorSentMsg1 bob n_a <: nsl_session);*
set_state alice sess_id (InitiatorSentMsg1 bob n_a <: nsl_session);*
return sess_id

val send_msg1: nsl_global_sess_ids -> principal -> nat -> crypto (option nat)
let send_msg1 global_sess_id alice sess_id =
let*? st: nsl_session = get_typed_state nsl_session_tag alice sess_id in
let*? st: nsl_session = get_state alice sess_id in
match st with
| InitiatorSentMsg1 bob n_a -> (
let*? pk_b = get_public_key alice global_sess_id.pki (PkEnc "NSL.PublicKey") bob in
Expand All @@ -85,12 +87,12 @@ let prepare_msg2 global_sess_id bob msg_id =
let* n_b = mk_rand NoUsage (join (principal_label msg1.alice) (principal_label bob)) 32 in
trigger_event bob (Respond1 msg1.alice bob msg1.n_a n_b);*
let* sess_id = new_session_id bob in
set_typed_state nsl_session_tag bob sess_id (ResponderSentMsg2 msg1.alice msg1.n_a n_b <: nsl_session);*
set_state bob sess_id (ResponderSentMsg2 msg1.alice msg1.n_a n_b <: nsl_session);*
return (Some sess_id)

val send_msg2: nsl_global_sess_ids -> principal -> nat -> crypto (option nat)
let send_msg2 global_sess_id bob sess_id =
let*? st: nsl_session = get_typed_state nsl_session_tag bob sess_id in
let*? st: nsl_session = get_state bob sess_id in
match st with
| ResponderSentMsg2 alice n_a n_b -> (
let*? pk_a = get_public_key bob global_sess_id.pki (PkEnc "NSL.PublicKey") alice in
Expand All @@ -105,19 +107,19 @@ val prepare_msg3: nsl_global_sess_ids -> principal -> nat -> nat -> crypto (opti
let prepare_msg3 global_sess_id alice sess_id msg_id =
let*? msg = recv_msg msg_id in
let*? sk_a = get_private_key alice global_sess_id.private_keys (PkDec "NSL.PublicKey") in
let*? st: nsl_session = get_typed_state nsl_session_tag alice sess_id in
let*? st: nsl_session = get_state alice sess_id in
match st with
| InitiatorSentMsg1 bob n_a -> (
let*? msg2: message2 = return (decode_message2 alice bob msg sk_a n_a) in
trigger_event alice (Initiate2 alice bob n_a msg2.n_b);*
set_typed_state nsl_session_tag alice sess_id (InitiatorSentMsg3 bob n_a msg2.n_b <: nsl_session);*
set_state alice sess_id (InitiatorSentMsg3 bob n_a msg2.n_b <: nsl_session);*
return (Some ())
)
| _ -> return None

val send_msg3: nsl_global_sess_ids -> principal -> nat -> crypto (option nat)
let send_msg3 global_sess_id alice sess_id =
let*? st: nsl_session = get_typed_state nsl_session_tag alice sess_id in
let*? st: nsl_session = get_state alice sess_id in
match st with
| InitiatorSentMsg3 bob n_a n_b -> (
let*? pk_b = get_public_key alice global_sess_id.pki (PkEnc "NSL.PublicKey") bob in
Expand All @@ -132,12 +134,12 @@ val prepare_msg4: nsl_global_sess_ids -> principal -> nat -> nat -> crypto (opti
let prepare_msg4 global_sess_id bob sess_id msg_id =
let*? msg = recv_msg msg_id in
let*? sk_b = get_private_key bob global_sess_id.private_keys (PkDec "NSL.PublicKey") in
let*? st: nsl_session = get_typed_state nsl_session_tag bob sess_id in
let*? st: nsl_session = get_state bob sess_id in
match st with
| ResponderSentMsg2 alice n_a n_b -> (
let*? msg3: message3 = return (decode_message3 alice bob msg sk_b n_b) in
trigger_event bob (Respond2 alice bob n_a n_b);*
set_typed_state nsl_session_tag bob sess_id (ResponderReceivedMsg3 alice n_a n_b <: nsl_session);*
set_state bob sess_id (ResponderReceivedMsg3 alice n_a n_b <: nsl_session);*
return (Some ())
)
| _ -> return None
12 changes: 6 additions & 6 deletions examples/nsl_pk/DY.Example.NSL.SecurityProperties.fst
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ val n_a_secrecy:
(requires
attacker_knows tr n_a /\
trace_invariant tr /\ (
(exists sess_id. typed_state_was_set tr nsl_session_tag alice sess_id (InitiatorSentMsg1 bob n_a)) \/
(exists sess_id n_b. typed_state_was_set tr nsl_session_tag alice sess_id (InitiatorSentMsg3 bob n_a n_b)) \/
(exists sess_id n_b. typed_state_was_set tr nsl_session_tag bob sess_id (ResponderReceivedMsg3 alice n_a n_b))
(exists sess_id. state_was_set tr alice sess_id (InitiatorSentMsg1 bob n_a)) \/
(exists sess_id n_b. state_was_set tr alice sess_id (InitiatorSentMsg3 bob n_a n_b)) \/
(exists sess_id n_b. state_was_set tr bob sess_id (ResponderReceivedMsg3 alice n_a n_b))
)
)
(ensures is_corrupt tr (principal_label alice) \/ is_corrupt tr (principal_label bob))
Expand All @@ -80,9 +80,9 @@ val n_b_secrecy:
(requires
attacker_knows tr n_b /\
trace_invariant tr /\ (
(exists sess_id n_a. typed_state_was_set tr nsl_session_tag bob sess_id (ResponderSentMsg2 alice n_a n_b)) \/
(exists sess_id n_a. typed_state_was_set tr nsl_session_tag bob sess_id (ResponderReceivedMsg3 alice n_a n_b)) \/
(exists sess_id n_a. typed_state_was_set tr nsl_session_tag alice sess_id (InitiatorSentMsg3 bob n_a n_b))
(exists sess_id n_a. state_was_set tr bob sess_id (ResponderSentMsg2 alice n_a n_b)) \/
(exists sess_id n_a. state_was_set tr bob sess_id (ResponderReceivedMsg3 alice n_a n_b)) \/
(exists sess_id n_a. state_was_set tr alice sess_id (InitiatorSentMsg3 bob n_a n_b))
)
)
(ensures is_corrupt tr (principal_label alice) \/ is_corrupt tr (principal_label bob))
Expand Down
Loading

0 comments on commit 452210f

Please sign in to comment.