Skip to content

Commit

Permalink
Replace static loop unrolling with dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
karoliineh committed Nov 6, 2024
1 parent 0e1cac9 commit 040c60b
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
12 changes: 8 additions & 4 deletions src/framework/constraints.ml
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ struct
let exit_loop max_iter x l : LoopCounts.t list =
List.init (max_iter + 1) (fun i -> LoopCounts.add x i l)

let unroll (v,(c,l)) (edges, u) max_iter =
let unroll (v,(c,l)) (edges, u) =
let open GobList.Syntax in
let u_heads = NodeH.find_default loop_heads u NodeSet.empty in
let v_heads = NodeH.find_default loop_heads v NodeSet.empty in
Expand All @@ -355,7 +355,10 @@ struct
(* For each node 'u' within a loop from where the loop is exited,
add loop counts from 0 up to the max nr of unroll iterations (max_iter).
For nested loops we have to add (combinations of) loop counts for all of the nests. *)
let ls = NodeSet.fold (fun exit ls -> List.concat_map (exit_loop max_iter exit) ls) exits [l] in
let ls = NodeSet.fold (fun exit ls ->
let unroll_factor = NodeH.find_default LoopUnrolling.factorH exit 0 in
List.concat_map (exit_loop unroll_factor exit) ls
) exits [l] in
(* For each node 'u' that is not in the same loop as node 'v',
i.e. the loop is entered for the first time from 'u',
if loop counts have reached 0, remove the loop counts to take the entry edge. *)
Expand All @@ -365,7 +368,8 @@ struct
- If the loop count includes max_iter, keep it to represent any remaining loop iterations that were not unrolled.
- If loop counts has reached 0 for 'v', stop calculating further loop counts, as unrolling stops and loop entry edge must be taken instead. *)
if is_back_edge then
List.concat_map (back_edge max_iter v) ls
let unroll_factor = NodeH.find_default LoopUnrolling.factorH v 0 in
List.concat_map (back_edge unroll_factor v) ls
else
ls

Expand Down Expand Up @@ -408,7 +412,7 @@ struct
let _, locs = List.fold_right (fun (f,e) (t,xs) -> f, (f,t)::xs) edges (Node.location v,[]) in
let res = List.fold_left2 (|>) pval (List.map (tf (v,(Obj.repr (fun () -> c),l)) getl sidel getg sideg u) edges) locs in
S.D.join acc res
) (S.D.bot ()) (unroll (v,(c,l)) (edges, u) 10) (* TODO: value hardcoded *)
) (S.D.bot ()) (unroll (v,(c,l)) (edges, u))

let tf (v,(c,l)) (e,u) getl sidel getg sideg =
let old_node = !current_node in
Expand Down
5 changes: 2 additions & 3 deletions src/util/cilCfg.ml
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,9 @@ let createCFG (fileAST: file) =
iterGlobals fileAST (fun glob ->
match glob with
| GFun(fd,_) ->
(* before prepareCfg so continues still appear as such *)
if (get_int "exp.unrolling-factor")>0 || AutoTune0.isActivated "loopUnrollHeuristic" then LoopUnrolling.unroll_loops fd loops;
prepareCFG fd;
computeCFGInfo fd true
computeCFGInfo fd true;
if (get_int "exp.unrolling-factor")>0 || AutoTune0.isActivated "loopUnrollHeuristic" then LoopUnrolling.unroll_loops fd loops;
| _ -> ()
);
if get_bool "dbg.run_cil_check" then assert (Check.checkFile [] fileAST);
14 changes: 4 additions & 10 deletions src/util/loopUnrolling.ml
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,8 @@ let copy_and_patch_labels break_target current_continue_target stmts =
let patchLabelsVisitor = new patchLabelsGotosVisitor(StatementHashTable.find_opt gotos) in
List.map (visitCilStmt patchLabelsVisitor) stmts'

let factorH = MyCFG.NodeH.create 100

class loopUnrollingVisitor (func, totalLoops) = object
(* Labels are simply handled by giving them a fresh name. Jumps coming from outside will still always go to the original label! *)
inherit nopCilVisitor
Expand All @@ -455,18 +457,10 @@ class loopUnrollingVisitor (func, totalLoops) = object
nests <- nests - 1; Logs.debug "nests: %i" nests;
let factor = loop_unrolling_factor stmt func totalLoops in
if factor > 0 then (
MyCFG.NodeH.add factorH (Statement (fst (CfgTools.find_real_stmt stmt))) factor;
Logs.info "unrolling loop at %a with factor %d" CilType.Location.pretty loc factor;
annotateArrays b;
(* top-level breaks should immediately go to the end of the loop, and not just break out of the current iteration *)
let break_target = { (Cil.mkEmptyStmt ()) with labels = [Label (Cil.freshLabel "loop_end",loc, false)]} in
let copies = List.init factor (fun i ->
(* continues should go to the next unrolling *)
let current_continue_target = { (Cil.mkEmptyStmt ()) with labels = [Label (Cil.freshLabel ("loop_continue_" ^ (string_of_int i)),loc, false)]} in
let one_copy_stmts = copy_and_patch_labels break_target current_continue_target b.bstmts in
one_copy_stmts @ [current_continue_target]
)
in
mkStmt (Block (mkBlock (List.flatten copies @ [stmt; break_target])))
stmt
) else stmt (*no change*)
| _ -> stmt
in
Expand Down
2 changes: 2 additions & 0 deletions src/util/loopUnrolling.mli
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ val unroll_loops: GoblintCil.fundec -> int -> unit

val find_original: GoblintCil.stmt -> GoblintCil.stmt
(** Find original un-unrolled instance of the statement. *)

val factorH : int CfgTools.NH.t

0 comments on commit 040c60b

Please sign in to comment.