Skip to content

Commit

Permalink
Compiler: Effects: keep track of CPS calls (#1648)
Browse files Browse the repository at this point in the history
Co-authored-by: Jérôme Vouillon <[email protected]>
  • Loading branch information
OlivierNicole and vouillon authored Aug 3, 2024
1 parent c484610 commit 285902b
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 44 deletions.
14 changes: 8 additions & 6 deletions compiler/lib/driver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ let phi p =

let ( +> ) f g x = g (f x)

let map_fst f (x, y) = f x, y
let map_fst f (x, y, z) = f x, y, z

let effects ~deadcode_sentinal p =
if Config.Flag.effects ()
Expand All @@ -104,9 +104,11 @@ let effects ~deadcode_sentinal p =
Deadcode.f p
else p, live_vars
in
let p, cps = p |> Effects.f ~flow_info:info ~live_vars +> map_fst Lambda_lifting.f in
p, cps)
else p, (Code.Var.Set.empty : Effects.cps_calls)
p |> Effects.f ~flow_info:info ~live_vars +> map_fst Lambda_lifting.f)
else
( p
, (Code.Var.Set.empty : Effects.trampolined_calls)
, (Code.Var.Set.empty : Effects.in_cps) )

let exact_calls profile ~deadcode_sentinal p =
if not (Config.Flag.effects ())
Expand Down Expand Up @@ -193,14 +195,14 @@ let generate
~wrap_with_fun
~warn_on_unhandled_effect
~deadcode_sentinal
((p, live_vars), cps_calls) =
((p, live_vars), trampolined_calls, _) =
if times () then Format.eprintf "Start Generation...@.";
let should_export = should_export wrap_with_fun in
Generate.f
p
~exported_runtime
~live_vars
~cps_calls
~trampolined_calls
~should_export
~warn_on_unhandled_effect
~deadcode_sentinal
Expand Down
41 changes: 29 additions & 12 deletions compiler/lib/effects.ml
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ let jump_closures blocks_to_transform idom : jump_closures =
idom
{ closure_of_jump = Addr.Map.empty; closures_of_alloc_site = Addr.Map.empty }

type cps_calls = Var.Set.t
type trampolined_calls = Var.Set.t

type in_cps = Var.Set.t

type st =
{ mutable new_blocks : Code.block Addr.Map.t * Code.Addr.t
Expand All @@ -263,7 +265,8 @@ type st =
; block_order : (Addr.t, int) Hashtbl.t
; live_vars : Deadcode.variable_uses
; flow_info : Global_flow.info
; cps_calls : cps_calls ref
; trampolined_calls : trampolined_calls ref
; in_cps : in_cps ref
}

let add_block st block =
Expand All @@ -280,10 +283,11 @@ let allocate_closure ~st ~params ~body ~branch loc =
let name = Var.fresh () in
[ Let (name, Closure (params, (pc, []))), loc ], name

let tail_call ~st ?(instrs = []) ~exact ~check ~f args loc =
let tail_call ~st ?(instrs = []) ~exact ~in_cps ~check ~f args loc =
assert (exact || check);
let ret = Var.fresh () in
if check then st.cps_calls := Var.Set.add ret !(st.cps_calls);
if check then st.trampolined_calls := Var.Set.add ret !(st.trampolined_calls);
if in_cps then st.in_cps := Var.Set.add ret !(st.in_cps);
instrs @ [ Let (ret, Apply { f; args; exact }), loc ], (Return ret, loc)

let cps_branch ~st ~src (pc, args) loc =
Expand All @@ -302,7 +306,15 @@ let cps_branch ~st ~src (pc, args) loc =
(* We check the stack depth only for backward edges (so, at
least once per loop iteration) *)
let check = Hashtbl.find st.block_order src >= Hashtbl.find st.block_order pc in
tail_call ~st ~instrs ~exact:true ~check ~f:(closure_of_pc ~st pc) args loc
tail_call
~st
~instrs
~exact:true
~in_cps:false
~check
~f:(closure_of_pc ~st pc)
args
loc

let cps_jump_cont ~st ~src ((pc, _) as cont) loc =
match Addr.Set.mem pc st.blocks_to_transform with
Expand Down Expand Up @@ -365,7 +377,7 @@ let cps_last ~st ~alloc_jump_closures pc ((last, last_loc) : last * loc) ~k :
(* Is the number of successive 'returns' is unbounded is CPS, it
means that we have an unbounded of calls in direct style
(even with tail call optimization) *)
tail_call ~st ~exact:true ~check:false ~f:k [ x ] last_loc
tail_call ~st ~exact:true ~in_cps:false ~check:false ~f:k [ x ] last_loc
| Raise (x, rmode) -> (
assert (List.is_empty alloc_jump_closures);
match Hashtbl.find_opt st.matching_exn_handler pc with
Expand Down Expand Up @@ -401,6 +413,7 @@ let cps_last ~st ~alloc_jump_closures pc ((last, last_loc) : last * loc) ~k :
~instrs:
((Let (exn_handler, Prim (Extern "caml_pop_trap", [])), noloc) :: instrs)
~exact:true
~in_cps:false
~check:false
~f:exn_handler
[ x ]
Expand Down Expand Up @@ -463,6 +476,7 @@ let cps_instr ~st (instr : instr) : instr =
(* Add the continuation parameter, and change the initial block if
needed *)
let k, cont = Hashtbl.find st.closure_info pc in
st.in_cps := Var.Set.add x !(st.in_cps);
Let (x, Closure (params @ [ k ], cont))
| Let (x, Prim (Extern "caml_alloc_dummy_function", [ size; arity ])) -> (
match arity with
Expand Down Expand Up @@ -532,7 +546,7 @@ let cps_block ~st ~k pc block =
let exact =
exact || Global_flow.exact_call st.flow_info f (List.length args)
in
tail_call ~st ~exact ~check:true ~f (args @ [ k ]) loc)
tail_call ~st ~exact ~in_cps:true ~check:true ~f (args @ [ k ]) loc)
| Prim (Extern "%resume", [ Pv stack; Pv f; Pv arg ]) ->
Some
(fun ~k ->
Expand All @@ -542,6 +556,7 @@ let cps_block ~st ~k pc block =
~instrs:
[ Let (k', Prim (Extern "caml_resume_stack", [ Pv stack; Pv k ])), noloc ]
~exact:(Global_flow.exact_call st.flow_info f 1)
~in_cps:true
~check:true
~f
[ arg; k' ]
Expand Down Expand Up @@ -599,7 +614,8 @@ let cps_block ~st ~k pc block =

let cps_transform ~live_vars ~flow_info ~cps_needed p =
let closure_info = Hashtbl.create 16 in
let cps_calls = ref Var.Set.empty in
let trampolined_calls = ref Var.Set.empty in
let in_cps = ref Var.Set.empty in
let p =
Code.fold_closures_innermost_first
p
Expand Down Expand Up @@ -658,7 +674,8 @@ let cps_transform ~live_vars ~flow_info ~cps_needed p =
; block_order = cfg.block_order
; flow_info
; live_vars
; cps_calls
; trampolined_calls
; in_cps
}
in
let function_needs_cps =
Expand Down Expand Up @@ -735,7 +752,7 @@ let cps_transform ~live_vars ~flow_info ~cps_needed p =
in
{ start = new_start; blocks; free_pc = new_start + 1 }
in
p, !cps_calls
p, !trampolined_calls, !in_cps

(****)

Expand Down Expand Up @@ -927,7 +944,7 @@ let f ~flow_info ~live_vars p =
let cps_needed = Partial_cps_analysis.f p flow_info in
let p, cps_needed = rewrite_toplevel ~cps_needed p in
let p = split_blocks ~cps_needed p in
let p, cps_calls = cps_transform ~live_vars ~flow_info ~cps_needed p in
let p, trampolined_calls, in_cps = cps_transform ~live_vars ~flow_info ~cps_needed p in
if Debug.find "times" () then Format.eprintf " effects: %a@." Timer.print t;
Code.invariant p;
p, cps_calls
p, trampolined_calls, in_cps
6 changes: 4 additions & 2 deletions compiler/lib/effects.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*)

type cps_calls = Code.Var.Set.t
type trampolined_calls = Code.Var.Set.t

val remove_empty_blocks : live_vars:Deadcode.variable_uses -> Code.program -> Code.program

type in_cps = Code.Var.Set.t

val f :
flow_info:Global_flow.info
-> live_vars:Deadcode.variable_uses
-> Code.program
-> Code.program * cps_calls
-> Code.program * trampolined_calls * in_cps
47 changes: 24 additions & 23 deletions compiler/lib/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ type fall_through =
type application_description =
{ arity : int
; exact : bool
; cps : bool
; trampolined : bool
}

module Share = struct
Expand Down Expand Up @@ -133,7 +133,7 @@ module Share = struct
| _ -> t)

let get
~cps_calls
~trampolined_calls
?alias_strings
?(alias_prims = false)
?(alias_apply = true)
Expand All @@ -150,9 +150,9 @@ module Share = struct
match i with
| Let (_, Constant c) -> get_constant c share
| Let (x, Apply { args; exact; _ }) ->
let cps = Var.Set.mem x cps_calls in
if (not exact) || cps
then add_apply { arity = List.length args; exact; cps } share
let trampolined = Var.Set.mem x trampolined_calls in
if (not exact) || trampolined
then add_apply { arity = List.length args; exact; trampolined } share
else share
| Let (_, Special (Alias_prim name)) ->
let name = Primitive.resolve name in
Expand Down Expand Up @@ -244,11 +244,11 @@ module Share = struct
try J.EVar (AppMap.find desc t.vars.applies)
with Not_found ->
let x =
let { arity; exact; cps } = desc in
let { arity; exact; trampolined } = desc in
Var.fresh_n
(Printf.sprintf
"caml_%scall%d"
(match exact, cps with
(match exact, trampolined with
| true, false -> assert false
| true, true -> "cps_exact_"
| false, false -> ""
Expand All @@ -269,7 +269,7 @@ module Ctx = struct
; exported_runtime : (Code.Var.t * bool ref) option
; should_export : bool
; effect_warning : bool ref
; cps_calls : Effects.cps_calls
; trampolined_calls : Effects.trampolined_calls
; deadcode_sentinal : Var.t
; mutated_vars : Code.Var.Set.t Code.Addr.Map.t
; freevars : Code.Var.Set.t Code.Addr.Map.t
Expand All @@ -284,7 +284,7 @@ module Ctx = struct
~freevars
blocks
live
cps_calls
trampolined_calls
share
debug =
{ blocks
Expand All @@ -294,7 +294,7 @@ module Ctx = struct
; exported_runtime
; should_export
; effect_warning = ref (not warn_on_unhandled_effect)
; cps_calls
; trampolined_calls
; deadcode_sentinal
; mutated_vars
; freevars
Expand Down Expand Up @@ -773,7 +773,7 @@ let parallel_renaming back_edge params args continuation queue =

(****)

let apply_fun_raw ctx f params exact cps =
let apply_fun_raw ctx f params exact trampolined =
let n = List.length params in
let apply_directly =
(* Make sure we are performing a regular call, not a (slower)
Expand Down Expand Up @@ -801,7 +801,7 @@ let apply_fun_raw ctx f params exact cps =
, apply_directly
, J.call (runtime_fun ctx "caml_call_gen") [ f; J.array params ] J.N )
in
if cps
if trampolined
then (
assert (Config.Flag.effects ());
(* When supporting effect, we systematically perform tailcall
Expand All @@ -814,7 +814,7 @@ let apply_fun_raw ctx f params exact cps =
, J.call (runtime_fun ctx "caml_trampoline_return") [ f; J.array params ] J.N ))
else apply

let generate_apply_fun ctx { arity; exact; cps } =
let generate_apply_fun ctx { arity; exact; trampolined } =
let f' = Var.fresh_n "f" in
let f = J.V f' in
let params =
Expand All @@ -829,23 +829,24 @@ let generate_apply_fun ctx { arity; exact; cps } =
( None
, J.fun_
(f :: params)
[ J.Return_statement (Some (apply_fun_raw ctx f' params' exact cps)), J.N ]
[ J.Return_statement (Some (apply_fun_raw ctx f' params' exact trampolined)), J.N
]
J.N )

let apply_fun ctx f params exact cps loc =
let apply_fun ctx f params exact trampolined loc =
(* We always go through an intermediate function when doing CPS
calls. This function first checks the stack depth to prevent
a stack overflow. This makes the code smaller than inlining
the test, and we expect the performance impact to be low
since the function should get inlined by the JavaScript
engines. *)
if Config.Flag.inline_callgen () || (exact && not cps)
then apply_fun_raw ctx f params exact cps
if Config.Flag.inline_callgen () || (exact && not trampolined)
then apply_fun_raw ctx f params exact trampolined
else
let y =
Share.get_apply
(generate_apply_fun ctx)
{ arity = List.length params; exact; cps }
{ arity = List.length params; exact; trampolined }
ctx.Ctx.share
in
J.call y (f :: params) loc
Expand Down Expand Up @@ -1028,7 +1029,7 @@ let throw_statement ctx cx k loc =
let rec translate_expr ctx queue loc x e level : _ * J.statement_list =
match e with
| Apply { f; args; exact } ->
let cps = Var.Set.mem x ctx.Ctx.cps_calls in
let trampolined = Var.Set.mem x ctx.Ctx.trampolined_calls in
let args, prop, queue =
List.fold_right
~f:(fun x (args, prop, queue) ->
Expand All @@ -1039,7 +1040,7 @@ let rec translate_expr ctx queue loc x e level : _ * J.statement_list =
in
let (prop', f), queue = access_queue queue f in
let prop = or_p prop prop' in
let e = apply_fun ctx f args exact cps loc in
let e = apply_fun ctx f args exact trampolined loc in
(e, prop, queue), []
| Block (tag, a, array_or_not, _mut) ->
let contents, prop, queue =
Expand Down Expand Up @@ -1948,13 +1949,13 @@ let f
(p : Code.program)
~exported_runtime
~live_vars
~cps_calls
~trampolined_calls
~should_export
~warn_on_unhandled_effect
~deadcode_sentinal
debug =
let t' = Timer.make () in
let share = Share.get ~cps_calls ~alias_prims:exported_runtime p in
let share = Share.get ~trampolined_calls ~alias_prims:exported_runtime p in
let exported_runtime =
if exported_runtime then Some (Code.Var.fresh_n "runtime", ref false) else None
in
Expand All @@ -1970,7 +1971,7 @@ let f
~freevars
p.blocks
live_vars
cps_calls
trampolined_calls
share
debug
in
Expand Down
2 changes: 1 addition & 1 deletion compiler/lib/generate.mli
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ val f :
Code.program
-> exported_runtime:bool
-> live_vars:Deadcode.variable_uses
-> cps_calls:Effects.cps_calls
-> trampolined_calls:Effects.trampolined_calls
-> should_export:bool
-> warn_on_unhandled_effect:bool
-> deadcode_sentinal:Code.Var.t
Expand Down

0 comments on commit 285902b

Please sign in to comment.