Skip to content

Commit

Permalink
Make the computation of derive traits automatic via another least fix…
Browse files Browse the repository at this point in the history
…ed point
  • Loading branch information
msprotz committed Oct 25, 2024
1 parent e98d333 commit c32bdf8
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 9 deletions.
70 changes: 62 additions & 8 deletions lib/AstToMiniRust.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
(* Low* to Rust backend *)

module LidMap = Idents.LidMap
module LidSet = Idents.LidSet

(* Location information *)

Expand Down Expand Up @@ -339,6 +340,11 @@ module DataTypeMap = Map.Make(struct
let compare = compare
end)

module TraitSet = Set.Make(struct
type t = MiniRust.trait
let compare = compare
end)

type env = {
decls: (MiniRust.name * MiniRust.typ) LidMap.t;
global_scope: NameSet.t;
Expand All @@ -347,8 +353,8 @@ type env = {
prefix: string list;
heap_structs: Idents.LidSet.t;
pointer_holding_structs: Idents.LidSet.t;
(* A map from lid (type name) to the list of fields for that struct. *)
struct_fields: MiniRust.struct_field list DataTypeMap.t;
derives: TraitSet.t LidMap.t;
location: location;
}

Expand All @@ -361,6 +367,7 @@ let empty heap_structs pointer_holding_structs = {
struct_fields = DataTypeMap.empty;
heap_structs;
pointer_holding_structs;
derives = LidMap.empty;
location = empty_loc;
}

Expand Down Expand Up @@ -1160,9 +1167,45 @@ let is_handled_primitively = function
| _ ->
false

let compute_derives heap_structs _pointer_holding_structs files =
let definitions = List.fold_left (fun map (_, decls) ->
List.fold_left (fun map decl -> LidMap.add (Ast.lid_of_decl decl) decl map) map decls
) LidMap.empty files in

let everything = TraitSet.of_list [ MiniRust.PartialEq; Clone; Copy ] in

let module F = Fix.Fix.ForOrderedType(struct
type t = Ast.lident
let compare = compare
end)(struct
type property = TraitSet.t
let bottom = everything
let equal = (=)
let is_maximal _ = false
end) in

let equations lid valuation =
let traits = object
inherit [_] Ast.reduce
method zero = everything
method plus = TraitSet.inter
method! visit_TQualified _ lid =
valuation lid
end#visit_decl () (LidMap.find lid definitions)
in
if LidSet.mem lid heap_structs then
(* If this type will contain a Box<...> then it cannot have trait copy. *)
TraitSet.diff traits (TraitSet.of_list [ MiniRust.Copy ])
else
traits
in

F.lfp equations


(* In Rust, like in C, all the declarations from the current module are in
* scope immediately. This requires us to duplicate a little bit of work. *)
let bind_decl env (d: Ast.decl): env =
let bind_decl env trait_valuation (d: Ast.decl): env =
match d with
| DFunction (_, _, _, _, _, lid, _, _) when is_handled_primitively lid ->
env
Expand Down Expand Up @@ -1220,6 +1263,7 @@ let bind_decl env (d: Ast.decl): env =
let env = push_type env lid name in
env, name
in
let derives = trait_valuation lid in
match decl with
| Flat fields ->
(* These sets are mutually exclusive, so we don't box *and* introduce a
Expand All @@ -1237,7 +1281,9 @@ let bind_decl env (d: Ast.decl): env =
let f = Option.get f in
{ MiniRust.name = f; visibility = Some Pub; typ = translate_type_with_config env { box; lifetime } t }
) fields in
{ env with struct_fields = DataTypeMap.add (`Struct lid) fields env.struct_fields }
{ env with
struct_fields = DataTypeMap.add (`Struct lid) fields env.struct_fields;
derives = LidMap.add lid derives env.derives }
| Variant branches ->
let box = Idents.LidSet.mem lid env.heap_structs in
let lifetime = Idents.LidSet.mem lid env.pointer_holding_structs in
Expand All @@ -1254,8 +1300,12 @@ let bind_decl env (d: Ast.decl): env =
{ MiniRust.name = f; visibility = Some Pub; typ = translate_type_with_config env { box; lifetime } t }
) fields
in
{ env with struct_fields = DataTypeMap.add cons_lid fields env.struct_fields }
{ env with
struct_fields = DataTypeMap.add cons_lid fields env.struct_fields;
derives = LidMap.add lid derives env.derives }
) env branches
| Enum _ ->
{ env with derives = LidMap.add lid derives env.derives }
| _ ->
env

Expand Down Expand Up @@ -1337,11 +1387,13 @@ let translate_decl env (d: Ast.decl): MiniRust.decl option =
in
let generic_params = match lifetime with Some l -> [ MiniRust.Lifetime l ] | None -> [] in
let fields = DataTypeMap.find (`Struct lid) env.struct_fields in
Some (Struct { name; meta; fields; generic_params })
let derives = List.of_seq (TraitSet.to_seq (LidMap.find lid env.derives)) in
Some (Struct { name; meta; fields; generic_params; derives })
| Enum idents ->
(* No need to do name binding here since there are entirely resolved via the type name. *)
let items = List.map (fun i -> snd i, None) idents in
Some (Enumeration { name; meta; items; derives = [ PartialEq; Clone; Copy ]; generic_params = [] })
let derives = List.of_seq (TraitSet.to_seq (LidMap.find lid env.derives)) in
Some (Enumeration { name; meta; items; derives; generic_params = [] })
| Abbrev t ->
let has_inner_pointer = (object
inherit [_] Ast.reduce
Expand Down Expand Up @@ -1372,7 +1424,8 @@ let translate_decl env (d: Ast.decl): MiniRust.decl option =
let fields = List.map (fun (x: MiniRust.struct_field) -> { x with visibility = None }) fields in
cons, Some fields
) branches in
Some (Enumeration { name; meta; items; derives = [ PartialEq; Clone; Copy ]; generic_params })
let derives = List.of_seq (TraitSet.to_seq (LidMap.find lid env.derives)) in
Some (Enumeration { name; meta; items; derives; generic_params })
| Union _ ->
Warn.failwith "TODO: Ast.DType (%a)\n" PrintAst.Ops.plid lid
| Forward _ ->
Expand Down Expand Up @@ -1497,6 +1550,7 @@ let compute_struct_info files =

let translate_files files =
let heap_structs, pointer_holding_structs = compute_struct_info files in
let derives = compute_derives heap_structs pointer_holding_structs files in
if Options.debug "rs-structs" then begin
KPrint.bprintf "The following types are understood to be heap-allocated:\n";
List.iter (KPrint.bprintf " %a\n" PrintAst.Ops.plid) (Idents.LidSet.elements heap_structs)
Expand Down Expand Up @@ -1525,7 +1579,7 @@ let translate_files files =
(* Step 1: bind all declarations and add them to the environment with their types *)
let env = List.fold_left (fun env d ->
try
bind_decl env d
bind_decl env derives d
with e ->
(* We do not increase failures as this will be counted below. *)
KPrint.bprintf "%sERROR translating type of %a: %s%s\n%s\n" Ansi.red
Expand Down
1 change: 1 addition & 0 deletions lib/MiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ type decl =
| Struct of {
name: name;
fields: struct_field list;
derives: trait list;
meta: meta;
generic_params: generic_param list;
}
Expand Down
3 changes: 2 additions & 1 deletion lib/PrintMiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,9 @@ let rec print_decl env (d: decl) =
| None -> empty
| Some item_struct -> break1 ^^ braces_with_nesting (print_struct env item_struct)
) items)
| Struct { fields; meta; generic_params; _ } ->
| Struct { fields; meta; generic_params; derives; _ } ->
group @@
group (print_derives derives) ^/^
group (print_meta meta ^^ string "struct" ^/^ string target_name ^^ print_generic_params generic_params) ^/^
braces_with_nesting (print_struct env fields)
| Alias { generic_params; body; meta; _ } ->
Expand Down

0 comments on commit c32bdf8

Please sign in to comment.