Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
msprotz committed Oct 23, 2024
1 parent 918bf50 commit ae91beb
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 86 deletions.
206 changes: 120 additions & 86 deletions lib/OptimizeMiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ let retrieve_pair_type = function
| Tuple [e1; e2] -> assert (e1 = e2); e1
| _ -> failwith "impossible: retrieve_pair_type"

let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr =
let rec infer (env: env) recurse (expected: typ) (known: known) (e: expr): known * expr =
if Options.debug "rs-mut" then
KPrint.bprintf "[infer] %a @ %a\n" pexpr e ptyp expected;
match e with
Expand All @@ -126,7 +126,7 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
| Open { atom; _ } ->
add_mut_var atom known, Borrow (Mut, e)
| Index (e1, (Range _ as r)) ->
let known, e1 = infer env expected known e1 in
let known, e1 = infer env recurse expected known e1 in
known, Borrow (Mut, Index (e1, r))

| Field (Open _, "0", None)
Expand All @@ -145,7 +145,7 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
KPrint.bprintf "[infer-mut, borrow] borrwing %a is not supported\n" pexpr e;
failwith "TODO: borrowing something other than a variable"
else
let known, e = infer env (assert_borrow expected) known e in
let known, e = infer env recurse (assert_borrow expected) known e in
known, Borrow (k, e)

| Open { atom; _ } ->
Expand All @@ -160,14 +160,14 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
(* KPrint.bprintf "[infer-mut,let] %a\n" pexpr e; *)
let a, e2 = open_ b e2 in
(* KPrint.bprintf "[infer-mut,let] opened %s[%s]\n" b.name (show_atom_t a); *)
let known, e2 = infer env expected known e2 in
let known, e2 = infer env recurse expected known e2 in
let mut_var = want_mut_var a known in
let mut_borrow = want_mut_borrow a known in
(* KPrint.bprintf "[infer-mut,let-done-e2] %s[%s]: %a let mut ? %b &mut ? %b\n" b.name *)
(* (show_atom_t a) *)
(* ptyp b.typ mut_var mut_borrow; *)
let t1 = if mut_borrow then make_mut_borrow b.typ else b.typ in
let known, e1 = infer env t1 known e1 in
let known, e1 = infer env recurse t1 known e1 in
known, Let ({ b with mut = mut_var; typ = t1 }, e1, close a (Var 0) (lift 1 e2))

| Call (Name n, targs, es) ->
Expand All @@ -176,7 +176,7 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
function that gets instantiated with a reference type *)
let ts = NameMap.find n env.seen in
let known, es = List.fold_left2 (fun (known, es) e t ->
let known, e = infer env t known e in
let known, e = infer env recurse t known e in
known, e :: es
) (known, []) es ts
in
Expand All @@ -186,22 +186,24 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
(* Since we do not have type-level substitutions in MiniRust, we special-case ignore here.
Ideally, it would be added to builtins with `Bound 0` as a suitable type for the
argument. *)
let known, e = infer env (KList.one targs) known (KList.one es) in
let known, e = infer env recurse (KList.one targs) known (KList.one es) in
known, Call (Name n, targs, [ e ])
else if n = ["Box"; "new"] then
let known, e = infer env (KList.one targs) known (KList.one es) in
let known, e = infer env recurse (KList.one targs) known (KList.one es) in
known, Call (Name n, targs, [ e ])
else if n = [ "lib"; "memzero0"; "memzero" ] then (
(* Same as ignore above *)
assert (List.length es = 2);
let e1, e2 = KList.two es in
let known, e1 = infer env (Ref (None, Mut, Slice (KList.one targs))) known e1 in
let known, e2 = infer env u32 known e2 in
let known, e1 = infer env recurse (Ref (None, Mut, Slice (KList.one targs))) known e1 in
let known, e2 = infer env recurse u32 known e2 in
known, Call (Name n, targs, [ e1; e2 ])
) else (
KPrint.bprintf "[infer-mut,call] recursing on %s\n" (String.concat " :: " n);
debug env;
failwith "TODO: recursion or missing function"
if Options.debug "rs-mut" then begin
KPrint.bprintf "[infer-mut,call] recursing on %s\n" (String.concat " :: " n);
debug env
end;
recurse env n
)

| Call (Operator o, [], _) -> begin match o with
Expand All @@ -222,7 +224,7 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
(* atom = e3 *)
| Assign (Open { atom; _ }, e3, t) ->
(* KPrint.bprintf "[infer-mut,assign] %a\n" pexpr e; *)
let known, e3 = infer env t known e3 in
let known, e3 = infer env recurse t known e3 in
add_mut_var atom known, e3

(* atom[e2] = e2 *)
Expand All @@ -240,23 +242,23 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
| Assign (Index (Field (Open {atom;_}, "1", None) as e1, e2), e3, t) ->
(* KPrint.bprintf "[infer-mut,assign] %a\n" pexpr e; *)
let known = add_mut_borrow atom known in
let known, e2 = infer env usize known e2 in
let known, e3 = infer env t known e3 in
let known, e2 = infer env recurse usize known e2 in
let known, e3 = infer env recurse t known e3 in
known, Assign (Index (e1, e2), e3, t)

(* (x.f)[e2] = e3 *)
| Assign (Index (Field (_, f, st (* optional type *)) as e1, e2), e3, t) ->
let known = add_mut_field st f known in
let known, e2 = infer env usize known e2 in
let known, e3 = infer env t known e3 in
let known, e2 = infer env recurse usize known e2 in
let known, e3 = infer env recurse t known e3 in
known, Assign (Index (e1, e2), e3, t)

(* (&atom)[e2] = e3 *)
| Assign (Index (Borrow (_, (Open { atom; _ } as e1)), e2), e3, t) ->
(* KPrint.bprintf "[infer-mut,assign] %a\n" pexpr e; *)
let known = add_mut_var atom known in
let known, e2 = infer env usize known e2 in
let known, e3 = infer env t known e3 in
let known, e2 = infer env recurse usize known e2 in
let known, e3 = infer env recurse t known e3 in
known, Assign (Index (Borrow (Mut, e1), e2), e3, t)

| Assign (Field (_, "0", None), _, _)
Expand All @@ -266,30 +268,30 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
(* (atom[e2]).f = e3 *)
| Assign (Field (Index ((Open {atom; _} as e1), e2), f, st), e3, t) ->
let known = add_mut_borrow atom known in
let known, e2 = infer env usize known e2 in
let known, e3 = infer env t known e3 in
let known, e2 = infer env recurse usize known e2 in
let known, e3 = infer env recurse t known e3 in
known, Assign (Field (Index (e1, e2), f, st), e3, t)

(* (&n)[e2] = e3 *)
| Assign (Index (Borrow (_, Name n), e2), e3, t) ->
(* This case should only occur for globals. For now, we simply mutably borrow it *)
let known, e2 = infer env usize known e2 in
let known, e3 = infer env t known e3 in
let known, e2 = infer env recurse usize known e2 in
let known, e3 = infer env recurse t known e3 in
known, Assign (Index (Borrow (Mut, Name n), e2), e3, t)

(* (&(&atom)[e2])[e3] = e4 *)
| Assign (Index (Borrow (_, Index (Borrow (_, (Open {atom; _} as e1)), e2)), e3), e4, t) ->
let known = add_mut_var atom known in
let known, e2 = infer env usize known e2 in
let known, e3 = infer env usize known e3 in
let known, e4 = infer env t known e4 in
let known, e2 = infer env recurse usize known e2 in
let known, e3 = infer env recurse usize known e3 in
let known, e4 = infer env recurse t known e4 in
known, Assign (Index (Borrow (Mut, Index (Borrow (Mut, e1), e2)), e3), e4, t)

(* (&(atom.f))[e1] = e2 *)
| Assign (Index (Borrow (_, Field (Open {atom; _} as e1, f, t)), e2), e3, t1) ->
let known = add_mut_var atom known in
let known, e2 = infer env usize known e2 in
let known, e3 = infer env usize known e3 in
let known, e2 = infer env recurse usize known e2 in
let known, e3 = infer env recurse usize known e3 in
known, Assign (Index (Borrow (Mut, Field (e1, f, t)), e2), e3, t1)

| Assign _ ->
Expand All @@ -309,12 +311,12 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
known, e

| IfThenElse (e1, e2, e3) ->
let known, e1 = infer env bool known e1 in
let known, e2 = infer env expected known e2 in
let known, e1 = infer env recurse bool known e1 in
let known, e2 = infer env recurse expected known e2 in
let known, e3 =
match e3 with
| Some e3 ->
let known, e3 = infer env expected known e3 in
let known, e3 = infer env recurse expected known e3 in
known, Some e3
| None ->
known, None
Expand All @@ -323,15 +325,15 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr

| As (e, t) ->
(* Not really correct, but As is only used for integer casts *)
let known, e = infer env t known e in
let known, e = infer env recurse t known e in
known, As (e, t)

| For (b, e1, e2) ->
let known, e2 = infer env Unit known e2 in
let known, e2 = infer env recurse Unit known e2 in
known, For (b, e1, e2)

| While (e1, e2) ->
let known, e2 = infer env Unit known e2 in
let known, e2 = infer env recurse Unit known e2 in
known, While (e1, e2)

| MethodCall (e1, m, e2) ->
Expand All @@ -347,9 +349,9 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
known, MethodCall (e1, m, e2)
| ["split_at"] ->
assert (List.length e2 = 1);
let known, e2 = infer env usize known (List.hd e2) in
let known, e2 = infer env recurse usize known (List.hd e2) in
let t1 = retrieve_pair_type expected in
let known, e1 = infer env t1 known e1 in
let known, e1 = infer env recurse t1 known e1 in
if is_mut_borrow expected then
known, MethodCall (e1, ["split_at_mut"], [e2])
else
Expand All @@ -359,8 +361,8 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
assert (List.length e2 = 1);
(* We do not have access to the types of e1 and e2. However, the concrete
type should not matter during mut inference, we thus use Unit as a default *)
let known, dst = infer env (Ref (None, Mut, Unit)) known dst in
let known, e2 = infer env (Ref (None, Shared, Unit)) known (List.hd e2) in
let known, dst = infer env recurse (Ref (None, Mut, Unit)) known dst in
let known, e2 = infer env recurse (Ref (None, Shared, Unit)) known (List.hd e2) in
known, MethodCall (Index (dst, range), m, [e2])
(* The AstToMiniRust translation should always introduce an index
as the left argument of copy_from_slice *)
Expand Down Expand Up @@ -388,10 +390,10 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr

| Match (e_scrut, t, arms) as _e_match ->
(* We have the expected type of the scrutinee: recurse *)
let known, e = infer env t known e_scrut in
let known, e = infer env recurse t known e_scrut in
let known, arms = List.fold_left_map (fun known ((bs, _, _) as branch) ->
let atoms, pat, e = open_branch branch in
let known, e = infer env expected known e in
let known, e = infer env recurse expected known e in
(* Given a pattern p of type t, and a known map:
i. if the pattern contains f = x *and* x is in R, then the field f of
the struct type (provided by the context t) needs to be mutable --
Expand Down Expand Up @@ -499,8 +501,8 @@ let rec infer (env: env) (expected: typ) (known: known) (e: expr): known * expr
earlier. This should therefore only occur when accessing a variable
in an array *)
let expected = Ref (None, Shared, expected) in
let known, e1 = infer env expected known e1 in
let known, e2 = infer env usize known e2 in
let known, e1 = infer env recurse expected known e1 in
let known, e2 = infer env recurse usize known e2 in
known, Index (e1, e2)

(* Special case for array slices. This occurs, e.g., when calling a function with
Expand Down Expand Up @@ -867,52 +869,84 @@ let infer_mut_borrows files =
(* Map.of_list is only available from OCaml 5.1 onwards *)
let env = { seen = List.to_seq builtins |> NameMap.of_seq; structs = DataTypeMap.empty } in
let known = { structs = DataTypeMap.empty; v = VarSet.empty; r = VarSet.empty; p = VarSet.empty } in

(* We must do a graph traversal since functions are potentially mutually-recursive at file scope *)
let module T = struct type color = White | Gray | Black end in
let open T in

let rec memoize map visit env lid =
let color, body = Hashtbl.find map lid in
match color with
| Gray ->
Warn.fatal_error "[Frames]: cyclic dependency on %a" PrintMiniRust.pname lid
| Black ->
env, body
| White ->
Hashtbl.replace map lid (Gray, body);
let env, body = visit env (memoize map visit) body in
Hashtbl.replace map lid (Black, body);
env, body
in

let map =
let map = Hashtbl.create 41 in
List.iter (fun (_, decls) ->
List.iter (fun d -> Hashtbl.add map (name_of_decl d) (White, d)) decls
) files;
map
in

let visit_one = memoize map (fun (env: env) recurse decl ->
match decl with
| Function ({ name; body; return_type; parameters; _ } as f) ->
if Options.debug "rs-mut" then
KPrint.bprintf "[infer-mut] visiting %s\n" (String.concat "::" name);
let atoms, body =
List.fold_right (fun binder (atoms, e) ->
let a, e = open_ binder e in
(* KPrint.bprintf "[infer-mut] opened %s[%s]\n%a\n" binder.name (show_atom_t a) pexpr e; *)
a :: atoms, e
) parameters ([], body)
in
(* KPrint.bprintf "[infer-mut] done opening %s\n%a\n" (String.concat "." name)
pexpr body; *)
(* Start the analysis with the current state of struct mutability *)
let known, body = infer env recurse return_type {known with structs = env.structs} body in
let parameters, body =
List.fold_left2 (fun (parameters, e) (binder: binding) atom ->
let e = close atom (Var 0) (lift 1 e) in
(* KPrint.bprintf "[infer-mut] closed %s[%s]\n%a\n" binder.name (show_atom_t atom) pexpr e; *)
let mut = want_mut_var atom known in
let typ = if want_mut_borrow atom known then make_mut_borrow binder.typ else binder.typ in
{ binder with mut; typ } :: parameters, e
) ([], body) parameters atoms
in
let parameters = List.rev parameters in
(* We update the environment in two ways. First, we add the function declaration,
with the mutability of the parameters inferred during the analysis.
Second, we propagate the information about the mutability of struct fields
inferred while traversing this function to the global environment. Note, since
the traversal does not add or remove any bindings, but only increases the
mutability, we can do a direct replacement instead of a more complex merge *)
let env = { seen = NameMap.add name (List.map (fun (x: binding) -> x.typ) parameters) env.seen; structs = known.structs } in
env, Function { f with body; parameters }
| Struct ({name; fields; _}) ->
{ env with structs = DataTypeMap.add (`Struct name) fields env.structs }, decl
| Enumeration { name; items; _ } ->
List.fold_left (fun (env: env) (cons, fields) ->
match fields with
| None -> env
| Some fields -> { env with structs = DataTypeMap.add (`Variant (name, cons)) fields env.structs }
) env items, decl
| _ ->
env, decl
) in

let env, files =
List.fold_left (fun (env, files) (filename, decls) ->
let env, decls = List.fold_left (fun (env, decls) decl ->
match decl with
| Function ({ name; body; return_type; parameters; _ } as f) ->
if Options.debug "rs-mut" then
KPrint.bprintf "[infer-mut] visiting %s\n" (String.concat "::" name);
let atoms, body =
List.fold_right (fun binder (atoms, e) ->
let a, e = open_ binder e in
(* KPrint.bprintf "[infer-mut] opened %s[%s]\n%a\n" binder.name (show_atom_t a) pexpr e; *)
a :: atoms, e
) parameters ([], body)
in
(* KPrint.bprintf "[infer-mut] done opening %s\n%a\n" (String.concat "." name)
pexpr body; *)
(* Start the analysis with the current state of struct mutability *)
let known, body = infer env return_type {known with structs = env.structs} body in
let parameters, body =
List.fold_left2 (fun (parameters, e) (binder: binding) atom ->
let e = close atom (Var 0) (lift 1 e) in
(* KPrint.bprintf "[infer-mut] closed %s[%s]\n%a\n" binder.name (show_atom_t atom) pexpr e; *)
let mut = want_mut_var atom known in
let typ = if want_mut_borrow atom known then make_mut_borrow binder.typ else binder.typ in
{ binder with mut; typ } :: parameters, e
) ([], body) parameters atoms
in
let parameters = List.rev parameters in
(* We update the environment in two ways. First, we add the function declaration,
with the mutability of the parameters inferred during the analysis.
Second, we propagate the information about the mutability of struct fields
inferred while traversing this function to the global environment. Note, since
the traversal does not add or remove any bindings, but only increases the
mutability, we can do a direct replacement instead of a more complex merge *)
let env = { seen = NameMap.add name (List.map (fun (x: binding) -> x.typ) parameters) env.seen; structs = known.structs } in
env, Function { f with body; parameters } :: decls
| Struct ({name; fields; _}) ->
{ env with structs = DataTypeMap.add (`Struct name) fields env.structs }, decl :: decls
| Enumeration { name; items; _ } ->
List.fold_left (fun (env: env) (cons, fields) ->
match fields with
| None -> env
| Some fields -> { env with structs = DataTypeMap.add (`Variant (name, cons)) fields env.structs }
) env items, decl :: decls
| _ ->
env, decl :: decls
let env, decl = visit_one env (name_of_decl decl) in
env, decl :: decls
) (env, []) decls in
let decls = List.rev decls in
env, (filename, decls) :: files
Expand Down
1 change: 1 addition & 0 deletions lib/PrintMiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ let print_decls ns ds =
) env.globals;
separate (hardline ^^ hardline) ds ^^ hardline

let pname = printf_of_pprint (print_name debug)
let pexpr = printf_of_pprint (print_expr debug max_int)
let ptyp = printf_of_pprint (print_typ debug)
let ptyps = printf_of_pprint (separate_map (comma ^^ break1) (print_typ debug))
Expand Down
1 change: 1 addition & 0 deletions src/Karamel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ Supported options:|}
* checking it. Note that bundling calls [drop_unused] already to do a first
* round of unused code elimination! *)
let files = Bundles.make_bundles files in
print PrintAst.print_files files;
let has_spinlock = List.exists (fun (_, ds) ->
List.exists (fun d ->
fst (Ast.lid_of_decl d) = [ "Steel"; "SpinLock" ]
Expand Down

0 comments on commit ae91beb

Please sign in to comment.