From 4fe16e83dc70626200966322010cb7ad2fadf09a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9ophile=20Wallez?= Date: Mon, 14 Oct 2024 11:54:32 +0200 Subject: [PATCH] feat(split fun): make `local_pred_t` dependent on `tag_set_t` (#59) * feat(split fun): make `local_pred_t` dependent on `tag_set_t` * update F* * apply Clara's remark --- .../crypto/DY.Lib.Crypto.KdfExpand.Split.fst | 22 ++-- .../crypto/DY.Lib.Crypto.SplitPredicate.fst | 19 +-- src/lib/event/DY.Lib.Event.Typed.fst | 12 +- src/lib/state/DY.Lib.State.Tagged.fst | 22 ++-- src/lib/utils/DY.Lib.SplitFunction.fst | 108 +++++++++++++----- 5 files changed, 118 insertions(+), 65 deletions(-) diff --git a/src/lib/crypto/DY.Lib.Crypto.KdfExpand.Split.fst b/src/lib/crypto/DY.Lib.Crypto.KdfExpand.Split.fst index 0877b2f..660b704 100644 --- a/src/lib/crypto/DY.Lib.Crypto.KdfExpand.Split.fst +++ b/src/lib/crypto/DY.Lib.Crypto.KdfExpand.Split.fst @@ -16,7 +16,7 @@ let split_kdf_expand_usage_get_usage_params: split_function_parameters = { Some (tag, (prk_usage, info)) ); - local_fun_t = kdf_expand_crypto_usage; + local_fun_t = mk_dependent_type kdf_expand_crypto_usage; global_fun_t = prk_usage:usage{KdfExpandKey? prk_usage} -> info:bytes -> usage; default_global_fun = (fun prk_usage info -> NoUsage); @@ -45,7 +45,7 @@ let split_kdf_expand_usage_get_label_params = { Some (tag, (prk_usage, prk_label, info)) ); - local_fun_t = kdf_expand_crypto_usage; + local_fun_t = mk_dependent_type kdf_expand_crypto_usage; global_fun_t = prk_usage:usage{KdfExpandKey? prk_usage} -> prk_label:label -> info:bytes -> label; default_global_fun = (fun prk_usage prk_label info -> prk_label); @@ -88,7 +88,7 @@ let has_kdf_expand_usage #cusgs (tag, local_invariant) = val intro_has_kdf_expand_usage_get_usage: {|crypto_usages|} -> tagged_local_invariant:(string & kdf_expand_crypto_usage) -> Lemma - (requires has_local_fun split_kdf_expand_usage_get_usage_params kdf_expand_usage.get_usage tagged_local_invariant) + (requires has_local_fun split_kdf_expand_usage_get_usage_params kdf_expand_usage.get_usage (mk_dependent_tagged_local_fun tagged_local_invariant)) (ensures has_kdf_expand_usage_get_usage tagged_local_invariant) let intro_has_kdf_expand_usage_get_usage #cusgs (tag, local_invariant) = introduce @@ -107,7 +107,7 @@ let intro_has_kdf_expand_usage_get_usage #cusgs (tag, local_invariant) = val intro_has_kdf_expand_usage_get_label: {|crypto_usages|} -> tagged_local_invariant:(string & kdf_expand_crypto_usage) -> Lemma - (requires has_local_fun split_kdf_expand_usage_get_label_params kdf_expand_usage.get_label tagged_local_invariant) + (requires has_local_fun split_kdf_expand_usage_get_label_params kdf_expand_usage.get_label (mk_dependent_tagged_local_fun tagged_local_invariant)) (ensures has_kdf_expand_usage_get_label tagged_local_invariant) let intro_has_kdf_expand_usage_get_label #cusgs (tag, local_invariant) = introduce @@ -130,14 +130,14 @@ val mk_global_kdf_expand_usage_get_usage: prk_usage:usage{KdfExpandKey? prk_usage} -> info:bytes -> usage let mk_global_kdf_expand_usage_get_usage tagged_local_invariants = - mk_global_fun (split_kdf_expand_usage_get_usage_params) tagged_local_invariants + mk_global_fun (split_kdf_expand_usage_get_usage_params) (mk_dependent_tagged_local_funs tagged_local_invariants) val mk_global_kdf_expand_usage_get_label: list (string & kdf_expand_crypto_usage) -> prk_usage:usage{KdfExpandKey? prk_usage} -> prk_label:label -> info:bytes -> label let mk_global_kdf_expand_usage_get_label tagged_local_invariants = - mk_global_fun (split_kdf_expand_usage_get_label_params) tagged_local_invariants + mk_global_fun (split_kdf_expand_usage_get_label_params) (mk_dependent_tagged_local_funs tagged_local_invariants) val mk_global_kdf_expand_usage_get_label_lemma: tagged_local_invariants:list (string & kdf_expand_crypto_usage) -> @@ -145,8 +145,8 @@ val mk_global_kdf_expand_usage_get_label_lemma: prk_usage:usage{KdfExpandKey? prk_usage} -> prk_label:label -> info:bytes -> Lemma ((mk_global_kdf_expand_usage_get_label tagged_local_invariants prk_usage prk_label info) `can_flow tr` prk_label) let mk_global_kdf_expand_usage_get_label_lemma tagged_local_invariants tr prk_usage prk_label info = - mk_global_fun_eq split_kdf_expand_usage_get_label_params tagged_local_invariants (prk_usage, prk_label, info); - introduce forall tagged_local_invariants. split_kdf_expand_usage_get_label_params.apply_local_fun tagged_local_invariants (prk_usage, prk_label, info) `can_flow tr` prk_label with ( + mk_global_fun_eq split_kdf_expand_usage_get_label_params (mk_dependent_tagged_local_funs tagged_local_invariants) (prk_usage, prk_label, info); + introduce forall tag_set tagged_local_invariants. split_kdf_expand_usage_get_label_params.apply_local_fun #tag_set tagged_local_invariants (prk_usage, prk_label, info) `can_flow tr` prk_label with ( tagged_local_invariants.get_label_lemma tr prk_usage prk_label info ) @@ -168,7 +168,9 @@ val mk_kdf_expand_usage_correct: let mk_kdf_expand_usage_correct #cusgs tagged_local_invariants = no_repeats_p_implies_for_all_pairsP_unequal (List.Tot.map fst tagged_local_invariants); for_allP_eq has_kdf_expand_usage tagged_local_invariants; - FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_fun_correct split_kdf_expand_usage_get_usage_params tagged_local_invariants)); - FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_fun_correct split_kdf_expand_usage_get_label_params tagged_local_invariants)); + map_dfst_mk_dependent_tagged_local_funs tagged_local_invariants; + FStar.Classical.forall_intro_2 (memP_mk_dependent_tagged_local_funs tagged_local_invariants); + FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_fun_correct split_kdf_expand_usage_get_usage_params (mk_dependent_tagged_local_funs tagged_local_invariants))); + FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_fun_correct split_kdf_expand_usage_get_label_params (mk_dependent_tagged_local_funs tagged_local_invariants))); FStar.Classical.forall_intro (FStar.Classical.move_requires intro_has_kdf_expand_usage_get_usage); FStar.Classical.forall_intro (FStar.Classical.move_requires intro_has_kdf_expand_usage_get_label) diff --git a/src/lib/crypto/DY.Lib.Crypto.SplitPredicate.fst b/src/lib/crypto/DY.Lib.Crypto.SplitPredicate.fst index 7221345..f45fdae 100644 --- a/src/lib/crypto/DY.Lib.Crypto.SplitPredicate.fst +++ b/src/lib/crypto/DY.Lib.Crypto.SplitPredicate.fst @@ -55,12 +55,12 @@ let split_crypto_predicate_parameters_to_split_function_parameters (params:split Some (tag, (tr, key, data)) ); - local_fun_t = params.local_pred_t; + local_fun_t = mk_dependent_type params.local_pred_t; global_fun_t = params.global_pred_t; default_global_fun = params.mk_global_pred always_false; - apply_local_fun = params.apply_local_pred; + apply_local_fun = (fun #tag_set -> params.apply_local_pred); apply_global_fun = params.apply_global_pred; mk_global_fun = params.mk_global_pred; apply_mk_global_fun = params.apply_mk_global_pred; @@ -71,7 +71,7 @@ val has_local_crypto_predicate: params.global_pred_t -> (string & params.local_pred_t) -> prop let has_local_crypto_predicate params global_pred (tag, local_pred) = - has_local_fun (split_crypto_predicate_parameters_to_split_function_parameters params) global_pred (tag, local_pred) + has_local_fun (split_crypto_predicate_parameters_to_split_function_parameters params) global_pred (|tag, local_pred|) val has_local_crypto_predicate_elim: params:split_crypto_predicate_parameters -> @@ -90,7 +90,7 @@ val mk_global_crypto_predicate: list (string & params.local_pred_t) -> params.global_pred_t let mk_global_crypto_predicate params tagged_local_preds = - mk_global_fun (split_crypto_predicate_parameters_to_split_function_parameters params) tagged_local_preds + mk_global_fun (split_crypto_predicate_parameters_to_split_function_parameters params) (mk_dependent_tagged_local_funs tagged_local_preds) val mk_global_crypto_predicate_later: params:split_crypto_predicate_parameters -> @@ -107,9 +107,9 @@ val mk_global_crypto_predicate_later: let mk_global_crypto_predicate_later params tagged_local_preds tr1 tr2 key data = let fparams = split_crypto_predicate_parameters_to_split_function_parameters params in params.apply_mk_global_pred always_false (tr1, key, data); - mk_global_fun_eq fparams tagged_local_preds (tr1, key, data); - mk_global_fun_eq fparams tagged_local_preds (tr2, key, data); - introduce forall lpred. fparams.apply_local_fun lpred (tr1, key, data) ==> fparams.apply_local_fun lpred (tr2, key, data) with ( + mk_global_fun_eq fparams (mk_dependent_tagged_local_funs tagged_local_preds) (tr1, key, data); + mk_global_fun_eq fparams (mk_dependent_tagged_local_funs tagged_local_preds) (tr2, key, data); + introduce forall tag_set lpred. fparams.apply_local_fun lpred (tr1, key, data) ==> fparams.apply_local_fun #tag_set lpred (tr2, key, data) with ( introduce _ ==> _ with _. params.apply_local_pred_later lpred tr1 tr2 key data ) @@ -125,5 +125,6 @@ val mk_global_crypto_predicate_correct: (ensures has_local_crypto_predicate params (mk_global_crypto_predicate params tagged_local_preds) (tag, local_pred)) let mk_global_crypto_predicate_correct params tagged_local_preds tag local_pred = no_repeats_p_implies_for_all_pairsP_unequal (List.Tot.map fst tagged_local_preds); - mk_global_fun_correct (split_crypto_predicate_parameters_to_split_function_parameters params) tagged_local_preds tag local_pred - + map_dfst_mk_dependent_tagged_local_funs tagged_local_preds; + memP_mk_dependent_tagged_local_funs tagged_local_preds tag local_pred; + mk_global_fun_correct (split_crypto_predicate_parameters_to_split_function_parameters params) (mk_dependent_tagged_local_funs tagged_local_preds) tag local_pred diff --git a/src/lib/event/DY.Lib.Event.Typed.fst b/src/lib/event/DY.Lib.Event.Typed.fst index bb00784..6fa337e 100644 --- a/src/lib/event/DY.Lib.Event.Typed.fst +++ b/src/lib/event/DY.Lib.Event.Typed.fst @@ -52,7 +52,7 @@ let split_event_pred_params: split_function_parameters = { Some (tag, (tr, prin, content)) )); - local_fun_t = trace -> principal -> bytes -> prop; + local_fun_t = mk_dependent_type (trace -> principal -> bytes -> prop); global_fun_t = trace -> principal -> string -> bytes -> prop; default_global_fun = (fun tr prin tag content -> False); @@ -69,7 +69,7 @@ let split_event_pred_params: split_function_parameters = { apply_mk_global_fun = (fun spred x -> ()); } -type compiled_event_predicate = split_event_pred_params.local_fun_t +type compiled_event_predicate = trace -> principal -> bytes -> prop val compile_event_pred: #a:Type0 -> {|event a|} -> @@ -84,7 +84,7 @@ let compile_event_pred #a #ev epred tr prin content_bytes = val has_compiled_event_pred: {|protocol_invariants|} -> (string & compiled_event_predicate) -> prop let has_compiled_event_pred #invs (tag, epred) = - has_local_fun split_event_pred_params event_pred (tag, epred) + has_local_fun split_event_pred_params event_pred (|tag, epred|) val has_event_pred: #a:Type0 -> {|event a|} -> @@ -96,7 +96,7 @@ let has_event_pred #a #ev #invs epred = val mk_event_pred: {|crypto_invariants|} -> list (string & compiled_event_predicate) -> trace -> principal -> string -> bytes -> prop let mk_event_pred #cinvs tagged_local_preds = - mk_global_fun split_event_pred_params tagged_local_preds + mk_global_fun split_event_pred_params (mk_dependent_tagged_local_funs tagged_local_preds) val mk_event_pred_correct: {|protocol_invariants|} -> tagged_local_preds:list (string & compiled_event_predicate) -> Lemma (requires @@ -108,7 +108,9 @@ let mk_event_pred_correct #invs tagged_local_preds = reveal_opaque (`%has_compiled_event_pred) (has_compiled_event_pred); no_repeats_p_implies_for_all_pairsP_unequal (List.Tot.map fst tagged_local_preds); for_allP_eq has_compiled_event_pred tagged_local_preds; - FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_fun_correct split_event_pred_params tagged_local_preds)) + map_dfst_mk_dependent_tagged_local_funs tagged_local_preds; + FStar.Classical.forall_intro_2 (memP_mk_dependent_tagged_local_funs tagged_local_preds); + FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_fun_correct split_event_pred_params (mk_dependent_tagged_local_funs tagged_local_preds))) (*** Monadic functions ***) diff --git a/src/lib/state/DY.Lib.State.Tagged.fst b/src/lib/state/DY.Lib.State.Tagged.fst index 6ef83a1..54cb0c3 100644 --- a/src/lib/state/DY.Lib.State.Tagged.fst +++ b/src/lib/state/DY.Lib.State.Tagged.fst @@ -53,7 +53,7 @@ let split_local_bytes_state_predicate_params {|crypto_invariants|} : split_funct | None -> None )); - local_fun_t = local_bytes_state_predicate; + local_fun_t = mk_dependent_type local_bytes_state_predicate; global_fun_t = trace -> principal -> state_id -> bytes -> prop; default_global_fun = (fun tr prin sess_id sess_content -> False); @@ -73,13 +73,13 @@ let split_local_bytes_state_predicate_params {|crypto_invariants|} : split_funct [@@ "opaque_to_smt"] val has_local_bytes_state_predicate: {|protocol_invariants|} -> (string & local_bytes_state_predicate) -> prop let has_local_bytes_state_predicate #invs (tag, spred) = - has_local_fun split_local_bytes_state_predicate_params state_pred.pred (tag, spred) + has_local_fun split_local_bytes_state_predicate_params state_pred.pred (|tag, spred|) (*** Global tagged state predicate builder ***) val mk_global_local_bytes_state_predicate: {|crypto_invariants|} -> list (string & local_bytes_state_predicate) -> trace -> principal -> state_id -> bytes -> prop let mk_global_local_bytes_state_predicate #cinvs tagged_local_preds = - mk_global_fun split_local_bytes_state_predicate_params tagged_local_preds + mk_global_fun split_local_bytes_state_predicate_params (mk_dependent_tagged_local_funs tagged_local_preds) #push-options "--ifuel 2" // to deconstruct nested tuples val mk_global_local_bytes_state_predicate_later: @@ -88,9 +88,9 @@ val mk_global_local_bytes_state_predicate_later: (requires mk_global_local_bytes_state_predicate tagged_local_preds tr1 prin sess_id full_content /\ tr1 <$ tr2) (ensures mk_global_local_bytes_state_predicate tagged_local_preds tr2 prin sess_id full_content) let mk_global_local_bytes_state_predicate_later #cinvs tagged_local_preds tr1 tr2 prin sess_id full_content = - mk_global_fun_eq split_local_bytes_state_predicate_params tagged_local_preds (tr1, prin, sess_id, full_content); - mk_global_fun_eq split_local_bytes_state_predicate_params tagged_local_preds (tr2, prin, sess_id, full_content); - introduce forall lpred content. split_local_bytes_state_predicate_params.apply_local_fun lpred (tr1, prin, sess_id, content) ==> split_local_bytes_state_predicate_params.apply_local_fun lpred (tr2, prin, sess_id, content) with ( + mk_global_fun_eq split_local_bytes_state_predicate_params (mk_dependent_tagged_local_funs tagged_local_preds) (tr1, prin, sess_id, full_content); + mk_global_fun_eq split_local_bytes_state_predicate_params (mk_dependent_tagged_local_funs tagged_local_preds) (tr2, prin, sess_id, full_content); + introduce forall tag_set lpred content. split_local_bytes_state_predicate_params.apply_local_fun #tag_set lpred (tr1, prin, sess_id, content) ==> split_local_bytes_state_predicate_params.apply_local_fun lpred (tr2, prin, sess_id, content) with ( introduce _ ==> _ with _. lpred.pred_later tr1 tr2 prin sess_id content ) #pop-options @@ -102,11 +102,11 @@ val mk_global_local_bytes_state_predicate_knowable: (requires mk_global_local_bytes_state_predicate tagged_local_preds tr prin sess_id full_content) (ensures is_knowable_by (principal_state_label prin sess_id) tr full_content) let mk_global_local_bytes_state_predicate_knowable #cinvs tagged_local_preds tr prin sess_id full_content = - mk_global_fun_eq split_local_bytes_state_predicate_params tagged_local_preds (tr, prin, sess_id, full_content); + mk_global_fun_eq split_local_bytes_state_predicate_params (mk_dependent_tagged_local_funs tagged_local_preds) (tr, prin, sess_id, full_content); match split_local_bytes_state_predicate_params.decode_tagged_data (tr, prin, sess_id, full_content) with | Some (tag, (_, _, _, content)) -> ( - match find_local_fun split_local_bytes_state_predicate_params tagged_local_preds tag with - | Some lpred -> ( + match find_local_fun split_local_bytes_state_predicate_params (mk_dependent_tagged_local_funs tagged_local_preds) tag with + | Some (|_, lpred|) -> ( lpred.pred_knowable tr prin sess_id content; serialize_parse_inv_lemma tagged_state full_content; serialize_wf_lemma tagged_state (is_knowable_by (principal_state_label prin sess_id) tr) ({tag; content}) @@ -133,7 +133,9 @@ let mk_state_pred_correct #invs tagged_local_preds = reveal_opaque (`%has_local_bytes_state_predicate) (has_local_bytes_state_predicate); no_repeats_p_implies_for_all_pairsP_unequal (List.Tot.map fst tagged_local_preds); for_allP_eq has_local_bytes_state_predicate tagged_local_preds; - FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_fun_correct split_local_bytes_state_predicate_params tagged_local_preds)) + map_dfst_mk_dependent_tagged_local_funs tagged_local_preds; + FStar.Classical.forall_intro_2 (FStar.Classical.move_requires_2 (mk_global_fun_correct split_local_bytes_state_predicate_params (mk_dependent_tagged_local_funs tagged_local_preds))); + FStar.Classical.forall_intro_2 (memP_mk_dependent_tagged_local_funs tagged_local_preds) (*** Predicates on trace ***) diff --git a/src/lib/utils/DY.Lib.SplitFunction.fst b/src/lib/utils/DY.Lib.SplitFunction.fst index 7a3bbad..64182ac 100644 --- a/src/lib/utils/DY.Lib.SplitFunction.fst +++ b/src/lib/utils/DY.Lib.SplitFunction.fst @@ -94,13 +94,13 @@ noeq type split_function_parameters = { decode_tagged_data: tagged_data_t -> option (tag_t & raw_data_t); // Types for the local functions and the global function - local_fun_t: Type; + local_fun_t: tag_set_t -> Type; global_fun_t: Type; default_global_fun: global_fun_t; // Apply a local function to its input - apply_local_fun: local_fun_t -> raw_data_t -> output_t; + apply_local_fun: #tag_set:tag_set_t -> local_fun_t tag_set -> raw_data_t -> output_t; // Apply the global function to its input apply_global_fun: global_fun_t -> tagged_data_t -> output_t; // Create a global function @@ -113,8 +113,8 @@ noeq type split_function_parameters = { /// Do a global function contain some given local function with some set of tags? /// This will be a crucial precondition for the correctness theorem `local_eq_global_lemma`. -val has_local_fun: params:split_function_parameters -> params.global_fun_t -> (params.tag_set_t & params.local_fun_t) -> prop -let has_local_fun params global_fun (tag_set, local_fun) = +val has_local_fun: params:split_function_parameters -> params.global_fun_t -> (dtuple2 params.tag_set_t params.local_fun_t) -> prop +let has_local_fun params global_fun (|tag_set, local_fun|) = forall tagged_data. match params.decode_tagged_data tagged_data with | Some (tag, raw_data) -> @@ -125,10 +125,10 @@ let has_local_fun params global_fun (tag_set, local_fun) = val has_local_fun_elim: params:split_function_parameters -> - global_fun:params.global_fun_t -> tag_set:params.tag_set_t -> local_fun:params.local_fun_t -> + global_fun:params.global_fun_t -> tag_set:params.tag_set_t -> local_fun:params.local_fun_t tag_set -> tagged_data:params.tagged_data_t -> Lemma - (requires has_local_fun params global_fun (tag_set, local_fun)) + (requires has_local_fun params global_fun (|tag_set, local_fun|)) (ensures ( match params.decode_tagged_data tagged_data with | Some (tag, raw_data) -> @@ -141,23 +141,24 @@ let has_local_fun_elim params global_fun tag_set local_fun tagged_data = () /// In practice, only one tag set may contain `tag` because tag sets are mutually disjoint /// (c.f. precondition of `mk_global_fun_correct`). /// In that case, this function returns the *unique* local function associated with a tag set containing `tag`. -val find_local_fun: params:split_function_parameters -> list (params.tag_set_t & params.local_fun_t) -> params.tag_t -> option params.local_fun_t +val find_local_fun: params:split_function_parameters -> l:list (dtuple2 params.tag_set_t params.local_fun_t) -> params.tag_t -> Tot (option (dtuple2 params.tag_set_t params.local_fun_t)) +(decreases List.Tot.length l) let rec find_local_fun params tagged_local_funs tag = match tagged_local_funs with | [] -> None - | (h_tag_set, h_local_fun)::t_tagged_local_funs -> ( + | (|h_tag_set, h_local_fun|)::t_tagged_local_funs -> ( if tag `params.tag_belong_to` h_tag_set then - Some h_local_fun + Some (|h_tag_set, h_local_fun|) else find_local_fun params t_tagged_local_funs tag ) -val mk_global_fun_aux: params:split_function_parameters -> list (params.tag_set_t & params.local_fun_t) -> params.tagged_data_t -> params.output_t +val mk_global_fun_aux: params:split_function_parameters -> list (dtuple2 params.tag_set_t params.local_fun_t) -> params.tagged_data_t -> params.output_t let mk_global_fun_aux params tagged_local_funs tagged_data = match params.decode_tagged_data tagged_data with | Some (tag_set, raw_data) -> ( match find_local_fun params tagged_local_funs tag_set with - | Some tagged_local_fun -> params.apply_local_fun tagged_local_fun raw_data + | Some (|_, tagged_local_fun|) -> params.apply_local_fun tagged_local_fun raw_data | None -> params.apply_global_fun params.default_global_fun tagged_data ) | None -> params.apply_global_fun params.default_global_fun tagged_data @@ -165,7 +166,7 @@ let mk_global_fun_aux params tagged_local_funs tagged_data = /// Given a list of tags and local functions, create the global function. [@@"opaque_to_smt"] -val mk_global_fun: params:split_function_parameters -> list (params.tag_set_t & params.local_fun_t) -> params.global_fun_t +val mk_global_fun: params:split_function_parameters -> list (dtuple2 params.tag_set_t params.local_fun_t) -> params.global_fun_t let mk_global_fun params tagged_local_funs = params.mk_global_fun (mk_global_fun_aux params tagged_local_funs) @@ -189,37 +190,37 @@ let rec for_all_pairsP #a disj l = | h::t -> (for_allP (disj h) t) /\ for_all_pairsP disj t val mk_global_fun_correct_aux: - params:split_function_parameters -> tagged_local_funs:list (params.tag_set_t & params.local_fun_t) -> tag_set:params.tag_set_t -> local_fun:params.local_fun_t -> tag:params.tag_t -> + params:split_function_parameters -> tagged_local_funs:list (dtuple2 params.tag_set_t params.local_fun_t) -> tag_set:params.tag_set_t -> local_fun:params.local_fun_t tag_set -> tag:params.tag_t -> Lemma (requires - for_all_pairsP (params.is_disjoint) (List.Tot.map fst tagged_local_funs) /\ + for_all_pairsP (params.is_disjoint) (List.Tot.map dfst tagged_local_funs) /\ tag `params.tag_belong_to` tag_set /\ - List.Tot.memP (tag_set, local_fun) tagged_local_funs + List.Tot.memP (|tag_set, local_fun|) tagged_local_funs ) - (ensures find_local_fun params tagged_local_funs tag == Some local_fun) -let rec mk_global_fun_correct_aux params tagged_local_funs tag_set tagged_local_fun tag = + (ensures find_local_fun params tagged_local_funs tag == Some (|tag_set, local_fun|)) +let rec mk_global_fun_correct_aux params tagged_local_funs tag_set local_fun tag = match tagged_local_funs with | [] -> () - | (h_tag_set, h_tagged_local_fun)::t_tagged_local_funs -> ( + | (|h_tag_set, h_tagged_local_fun|)::t_tagged_local_funs -> ( if tag `params.tag_belong_to` h_tag_set then ( - introduce (List.Tot.memP (tag_set, tagged_local_fun) t_tagged_local_funs) ==> False with _. ( - for_allP_eq (params.is_disjoint h_tag_set) (List.Tot.map fst t_tagged_local_funs); - memP_map fst t_tagged_local_funs (tag_set, tagged_local_fun); + introduce (List.Tot.memP (|tag_set, local_fun|) t_tagged_local_funs) ==> False with _. ( + for_allP_eq (params.is_disjoint h_tag_set) (List.Tot.map dfst t_tagged_local_funs); + memP_map dfst t_tagged_local_funs (|tag_set, local_fun|); params.cant_belong_to_disjoint_sets tag h_tag_set tag_set ) ) else ( - mk_global_fun_correct_aux params t_tagged_local_funs tag_set tagged_local_fun tag + mk_global_fun_correct_aux params t_tagged_local_funs tag_set local_fun tag ) ) val mk_global_fun_correct: - params:split_function_parameters -> tagged_local_funs:list (params.tag_set_t & params.local_fun_t) -> tag_set:params.tag_set_t -> local_fun:params.local_fun_t -> + params:split_function_parameters -> tagged_local_funs:list (dtuple2 params.tag_set_t params.local_fun_t) -> tag_set:params.tag_set_t -> local_fun:params.local_fun_t tag_set -> Lemma (requires - for_all_pairsP (params.is_disjoint) (List.Tot.map fst tagged_local_funs) /\ - List.Tot.memP (tag_set, local_fun) tagged_local_funs + for_all_pairsP (params.is_disjoint) (List.Tot.map dfst tagged_local_funs) /\ + List.Tot.memP (|tag_set, local_fun|) tagged_local_funs ) - (ensures has_local_fun params (mk_global_fun params tagged_local_funs) (tag_set, local_fun)) + (ensures has_local_fun params (mk_global_fun params tagged_local_funs) (|tag_set, local_fun|)) let mk_global_fun_correct params tagged_local_funs tag_set local_fun = reveal_opaque (`%mk_global_fun) (mk_global_fun); introduce @@ -245,14 +246,14 @@ let mk_global_fun_correct params tagged_local_funs tag_set local_fun = /// (e.g. the function keep the same output when the trace grows.) val mk_global_fun_eq: - params:split_function_parameters -> tagged_local_funs:list (params.tag_set_t & params.local_fun_t) -> + params:split_function_parameters -> tagged_local_funs:list (dtuple2 params.tag_set_t params.local_fun_t) -> tagged_data:params.tagged_data_t -> Lemma ( params.apply_global_fun (mk_global_fun params tagged_local_funs) tagged_data == ( match params.decode_tagged_data tagged_data with | Some (tag, raw_data) -> ( match find_local_fun params tagged_local_funs tag with - | Some tagged_local_fun -> params.apply_local_fun tagged_local_fun raw_data + | Some (|_, tagged_local_fun|) -> params.apply_local_fun tagged_local_fun raw_data | None -> params.apply_global_fun params.default_global_fun tagged_data ) | None -> params.apply_global_fun params.default_global_fun tagged_data @@ -268,13 +269,13 @@ let mk_global_fun_eq params tagged_local_funs tagged_data = val local_eq_global_lemma: params:split_function_parameters -> - global_fun:params.global_fun_t -> tag_set:params.tag_set_t -> local_fun:params.local_fun_t -> + global_fun:params.global_fun_t -> tag_set:params.tag_set_t -> local_fun:params.local_fun_t tag_set -> tagged_data:params.tagged_data_t -> tag:params.tag_t -> raw_data:params.raw_data_t -> Lemma (requires params.decode_tagged_data tagged_data == Some (tag, raw_data) /\ tag `params.tag_belong_to` tag_set /\ - has_local_fun params global_fun (tag_set, local_fun) + has_local_fun params global_fun (|tag_set, local_fun|) ) (ensures params.apply_global_fun global_fun tagged_data == params.apply_local_fun local_fun raw_data) let local_eq_global_lemma params global_fun tag_set tagged_local_fun tagged_data tag raw_data = () @@ -322,7 +323,7 @@ let singleton_split_function_parameters (a:eqtype): split_function_parameters = output_t = unit; decode_tagged_data = (fun x -> None); - local_fun_t = unit; + local_fun_t = (fun _ -> unit); global_fun_t = unit; default_global_fun = (); @@ -332,3 +333,48 @@ let singleton_split_function_parameters (a:eqtype): split_function_parameters = mk_global_fun = (fun bare -> ()); apply_mk_global_fun = (fun bare x -> ()); } + +/// When the `local_fun_t` doesn't depend on a `tag_set_t`, +/// the following functions and lemmas are handy. + +val mk_dependent_type: + #tag_set_t:Type -> Type u#a -> + tag_set_t -> Type u#a +let mk_dependent_type #tag_set_t local_fun_t _ = local_fun_t + +val mk_dependent_tagged_local_fun: + #tag_set_t:Type -> #local_fun_t:Type -> + tag_set_t & local_fun_t -> + dtuple2 tag_set_t (mk_dependent_type local_fun_t) +let mk_dependent_tagged_local_fun #tag_set_t #local_fun_t (tag_set, local_fun) = + (|tag_set, local_fun|) + +val mk_dependent_tagged_local_funs: + #tag_set_t:Type -> #local_fun_t:Type -> + list (tag_set_t & local_fun_t) -> + list (dtuple2 tag_set_t (mk_dependent_type local_fun_t)) +let mk_dependent_tagged_local_funs #tag_set_t #local_fun_t l = + List.Tot.map mk_dependent_tagged_local_fun l + +val map_dfst_mk_dependent_tagged_local_funs: + #tag_set_t:Type -> #local_fun_t:Type -> + tagged_local_funs:list (tag_set_t & local_fun_t) -> + Lemma (List.Tot.map fst tagged_local_funs == List.Tot.map dfst (mk_dependent_tagged_local_funs tagged_local_funs)) +let rec map_dfst_mk_dependent_tagged_local_funs #tag_set_t #local_fun_t l = + match l with + | [] -> () + | (x,y)::t -> map_dfst_mk_dependent_tagged_local_funs t + +val memP_mk_dependent_tagged_local_funs: + #tag_set_t:Type -> #local_fun_t:Type -> + tagged_local_funs:list (tag_set_t & local_fun_t) -> + tag_set:tag_set_t -> local_fun:local_fun_t -> + Lemma ( + List.Tot.Base.memP (tag_set, local_fun) tagged_local_funs ==> + List.Tot.Base.memP (|tag_set, local_fun|) (mk_dependent_tagged_local_funs tagged_local_funs) + ) +let rec memP_mk_dependent_tagged_local_funs #tag_set_t #local_fun_t l tag_set local_fun = + match l with + | [] -> () + | (x,y)::t -> + memP_mk_dependent_tagged_local_funs t tag_set local_fun