From c32bdf87e48c52450c5aeac6c8105ae1c1f1a38d Mon Sep 17 00:00:00 2001 From: Jonathan Protzenko Date: Fri, 25 Oct 2024 11:19:23 -0700 Subject: [PATCH] Make the computation of derive traits automatic via another least fixed point --- lib/AstToMiniRust.ml | 70 +++++++++++++++++++++++++++++++++++++++----- lib/MiniRust.ml | 1 + lib/PrintMiniRust.ml | 3 +- 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/lib/AstToMiniRust.ml b/lib/AstToMiniRust.ml index 3c4d6ed4..830d6662 100644 --- a/lib/AstToMiniRust.ml +++ b/lib/AstToMiniRust.ml @@ -1,6 +1,7 @@ (* Low* to Rust backend *) module LidMap = Idents.LidMap +module LidSet = Idents.LidSet (* Location information *) @@ -339,6 +340,11 @@ module DataTypeMap = Map.Make(struct let compare = compare end) +module TraitSet = Set.Make(struct + type t = MiniRust.trait + let compare = compare +end) + type env = { decls: (MiniRust.name * MiniRust.typ) LidMap.t; global_scope: NameSet.t; @@ -347,8 +353,8 @@ type env = { prefix: string list; heap_structs: Idents.LidSet.t; pointer_holding_structs: Idents.LidSet.t; - (* A map from lid (type name) to the list of fields for that struct. *) struct_fields: MiniRust.struct_field list DataTypeMap.t; + derives: TraitSet.t LidMap.t; location: location; } @@ -361,6 +367,7 @@ let empty heap_structs pointer_holding_structs = { struct_fields = DataTypeMap.empty; heap_structs; pointer_holding_structs; + derives = LidMap.empty; location = empty_loc; } @@ -1160,9 +1167,45 @@ let is_handled_primitively = function | _ -> false +let compute_derives heap_structs _pointer_holding_structs files = + let definitions = List.fold_left (fun map (_, decls) -> + List.fold_left (fun map decl -> LidMap.add (Ast.lid_of_decl decl) decl map) map decls + ) LidMap.empty files in + + let everything = TraitSet.of_list [ MiniRust.PartialEq; Clone; Copy ] in + + let module F = Fix.Fix.ForOrderedType(struct + type t = Ast.lident + let compare = compare + end)(struct + type property = TraitSet.t + let bottom = everything + let equal = (=) + let is_maximal _ = false + end) in + + let equations lid valuation = + let traits = object + inherit [_] Ast.reduce + method zero = everything + method plus = TraitSet.inter + method! visit_TQualified _ lid = + valuation lid + end#visit_decl () (LidMap.find lid definitions) + in + if LidSet.mem lid heap_structs then + (* If this type will contain a Box<...> then it cannot have trait copy. *) + TraitSet.diff traits (TraitSet.of_list [ MiniRust.Copy ]) + else + traits + in + + F.lfp equations + + (* In Rust, like in C, all the declarations from the current module are in * scope immediately. This requires us to duplicate a little bit of work. *) -let bind_decl env (d: Ast.decl): env = +let bind_decl env trait_valuation (d: Ast.decl): env = match d with | DFunction (_, _, _, _, _, lid, _, _) when is_handled_primitively lid -> env @@ -1220,6 +1263,7 @@ let bind_decl env (d: Ast.decl): env = let env = push_type env lid name in env, name in + let derives = trait_valuation lid in match decl with | Flat fields -> (* These sets are mutually exclusive, so we don't box *and* introduce a @@ -1237,7 +1281,9 @@ let bind_decl env (d: Ast.decl): env = let f = Option.get f in { MiniRust.name = f; visibility = Some Pub; typ = translate_type_with_config env { box; lifetime } t } ) fields in - { env with struct_fields = DataTypeMap.add (`Struct lid) fields env.struct_fields } + { env with + struct_fields = DataTypeMap.add (`Struct lid) fields env.struct_fields; + derives = LidMap.add lid derives env.derives } | Variant branches -> let box = Idents.LidSet.mem lid env.heap_structs in let lifetime = Idents.LidSet.mem lid env.pointer_holding_structs in @@ -1254,8 +1300,12 @@ let bind_decl env (d: Ast.decl): env = { MiniRust.name = f; visibility = Some Pub; typ = translate_type_with_config env { box; lifetime } t } ) fields in - { env with struct_fields = DataTypeMap.add cons_lid fields env.struct_fields } + { env with + struct_fields = DataTypeMap.add cons_lid fields env.struct_fields; + derives = LidMap.add lid derives env.derives } ) env branches + | Enum _ -> + { env with derives = LidMap.add lid derives env.derives } | _ -> env @@ -1337,11 +1387,13 @@ let translate_decl env (d: Ast.decl): MiniRust.decl option = in let generic_params = match lifetime with Some l -> [ MiniRust.Lifetime l ] | None -> [] in let fields = DataTypeMap.find (`Struct lid) env.struct_fields in - Some (Struct { name; meta; fields; generic_params }) + let derives = List.of_seq (TraitSet.to_seq (LidMap.find lid env.derives)) in + Some (Struct { name; meta; fields; generic_params; derives }) | Enum idents -> (* No need to do name binding here since there are entirely resolved via the type name. *) let items = List.map (fun i -> snd i, None) idents in - Some (Enumeration { name; meta; items; derives = [ PartialEq; Clone; Copy ]; generic_params = [] }) + let derives = List.of_seq (TraitSet.to_seq (LidMap.find lid env.derives)) in + Some (Enumeration { name; meta; items; derives; generic_params = [] }) | Abbrev t -> let has_inner_pointer = (object inherit [_] Ast.reduce @@ -1372,7 +1424,8 @@ let translate_decl env (d: Ast.decl): MiniRust.decl option = let fields = List.map (fun (x: MiniRust.struct_field) -> { x with visibility = None }) fields in cons, Some fields ) branches in - Some (Enumeration { name; meta; items; derives = [ PartialEq; Clone; Copy ]; generic_params }) + let derives = List.of_seq (TraitSet.to_seq (LidMap.find lid env.derives)) in + Some (Enumeration { name; meta; items; derives; generic_params }) | Union _ -> Warn.failwith "TODO: Ast.DType (%a)\n" PrintAst.Ops.plid lid | Forward _ -> @@ -1497,6 +1550,7 @@ let compute_struct_info files = let translate_files files = let heap_structs, pointer_holding_structs = compute_struct_info files in + let derives = compute_derives heap_structs pointer_holding_structs files in if Options.debug "rs-structs" then begin KPrint.bprintf "The following types are understood to be heap-allocated:\n"; List.iter (KPrint.bprintf " %a\n" PrintAst.Ops.plid) (Idents.LidSet.elements heap_structs) @@ -1525,7 +1579,7 @@ let translate_files files = (* Step 1: bind all declarations and add them to the environment with their types *) let env = List.fold_left (fun env d -> try - bind_decl env d + bind_decl env derives d with e -> (* We do not increase failures as this will be counted below. *) KPrint.bprintf "%sERROR translating type of %a: %s%s\n%s\n" Ansi.red diff --git a/lib/MiniRust.ml b/lib/MiniRust.ml index c647bd4d..902c050c 100644 --- a/lib/MiniRust.ml +++ b/lib/MiniRust.ml @@ -184,6 +184,7 @@ type decl = | Struct of { name: name; fields: struct_field list; + derives: trait list; meta: meta; generic_params: generic_param list; } diff --git a/lib/PrintMiniRust.ml b/lib/PrintMiniRust.ml index 241194b4..017a0c82 100644 --- a/lib/PrintMiniRust.ml +++ b/lib/PrintMiniRust.ml @@ -587,8 +587,9 @@ let rec print_decl env (d: decl) = | None -> empty | Some item_struct -> break1 ^^ braces_with_nesting (print_struct env item_struct) ) items) - | Struct { fields; meta; generic_params; _ } -> + | Struct { fields; meta; generic_params; derives; _ } -> group @@ + group (print_derives derives) ^/^ group (print_meta meta ^^ string "struct" ^/^ string target_name ^^ print_generic_params generic_params) ^/^ braces_with_nesting (print_struct env fields) | Alias { generic_params; body; meta; _ } ->