Skip to content

Improved optimization pass scheduling #1962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 6, 2025
Merged
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
* Compiler: improve debug/sourcemap location of closures (#1947)
* Compiler: improve tailcall optimization (#1943)
* Runtime: use Dataview to convert between floats and bit representation
* Compiler: speed up compilation by improving the scheduling of optimization passes (#1962)

## Bug fixes
* Compiler: fix stack overflow issues with double translation (#1869)
Expand Down
37 changes: 29 additions & 8 deletions compiler/lib/code.ml
Original file line number Diff line number Diff line change
Expand Up @@ -835,11 +835,35 @@ let check_updates ~name p1 p2 ~updates =
print_diff p1 p2;
assert false

(* Equality on continuations: both must target the same block address and
   carry pairwise-equal argument variables (compared with [Var.equal]). *)
let cont_equal (pc, args) (pc', args') =
  if pc = pc' then List.equal ~eq:Var.equal args args' else false

(* Ordering on continuations: the target addresses are compared first; the
   argument lists (element-wise, with [Var.compare]) only break ties. *)
let cont_compare (pc, args) (pc', args') =
  match compare pc pc' with
  | 0 -> List.compare ~cmp:Var.compare args args'
  | order -> order

(* Flag obtained from [Debug.find] under the name "invariant"; [invariant]
   only performs its checks when calling this returns [true]. *)
let with_invariant = Debug.find "invariant"

(* Compile-time switch, currently off.
   NOTE(review): presumably enables the variable-definition checks inside
   [invariant] -- its use site is not visible here, confirm. *)
let check_defs = false

let invariant { blocks; start; _ } =
(* [used_blocks p] returns the set of block addresses reachable from
   [p.start]. A block is reachable either as a successor of a reachable
   block (as enumerated by [fold_children]) or as the entry point of a
   closure created by a [Let (_, Closure _)] instruction in the body of a
   reachable block.

   The lookup in [p.blocks] is unguarded, so a reachable address with no
   corresponding block makes [Addr.Map.find] raise. *)
let used_blocks p =
  let seen = BitSet.create' p.free_pc in
  let rec visit pc =
    if BitSet.mem seen pc
    then ()
    else (
      (* Mark before recursing so that cycles terminate. *)
      BitSet.set seen pc;
      let closure_entry = function
        | Let (_, Closure (_, (pc', _), _)) -> visit pc'
        | _ -> ()
      in
      let block = Addr.Map.find pc p.blocks in
      List.iter block.body ~f:closure_entry;
      fold_children p.blocks pc (fun succ () -> visit succ) ())
  in
  visit p.start;
  seen

let invariant ({ blocks; start; _ } as p) =
if with_invariant ()
then (
assert (Addr.Map.mem start blocks);
Expand Down Expand Up @@ -889,6 +913,7 @@ let invariant { blocks; start; _ } =
| Stop -> ()
| Branch cont -> check_cont cont
| Cond (_x, cont1, cont2) ->
assert (not (cont_equal cont1 cont2));
check_cont cont1;
check_cont cont2
| Switch (_x, a1) -> Array.iteri a1 ~f:(fun _ cont -> check_cont cont)
Expand All @@ -897,16 +922,12 @@ let invariant { blocks; start; _ } =
check_cont cont2
| Poptrap cont -> check_cont cont
in
let visited = used_blocks p in
Addr.Map.iter
(fun _pc block ->
(fun pc block ->
assert (BitSet.mem visited pc);
List.iter block.params ~f:define;
List.iter block.body ~f:check_instr;
check_events block.body;
check_last block.branch)
blocks)

let cont_equal (pc, args) (pc', args') = pc = pc' && List.equal ~eq:Var.equal args args'

let cont_compare (pc, args) (pc', args') =
let c = compare pc pc' in
if c <> 0 then c else List.compare ~cmp:Var.compare args args'
2 changes: 2 additions & 0 deletions compiler/lib/code.mli
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ val traverse :
val preorder_traverse :
fold_blocs_poly -> (Addr.t -> 'c -> 'c) -> Addr.t -> block Addr.Map.t -> 'c -> 'c

val used_blocks : program -> BitSet.t

val prepend : program -> instr list -> program

val empty : program
Expand Down
128 changes: 80 additions & 48 deletions compiler/lib/deadcode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type t =
; mutable deleted_instrs : int
; mutable deleted_blocks : int
; mutable deleted_params : int
; mutable block_shortcut : int
}

(****)
Expand Down Expand Up @@ -185,6 +186,30 @@ let annot st pc xi =

(****)

(* [remove_unused_blocks' p] drops from [p.blocks] every block whose address
   is not in [Code.used_blocks p]. Returns the pruned program together with
   the number of blocks that were removed. *)
let remove_unused_blocks' p =
  let reachable = Code.used_blocks p in
  let removed = ref 0 in
  let keep pc _block =
    let is_used = BitSet.mem reachable pc in
    if not is_used then incr removed;
    is_used
  in
  let blocks = Addr.Map.filter keep p.blocks in
  { p with blocks }, !removed

(* [remove_unused_blocks p] is [remove_unused_blocks' p] without the count,
   plus optional reporting: elapsed time when [times ()] holds, the number of
   deleted blocks when [stats ()] holds, and a [Code.check_updates]
   comparison of the program before and after when [debug_stats ()] holds. *)
let remove_unused_blocks p =
  (* Start the timer before doing any work so it covers the whole pass. *)
  let timer = Timer.make () in
  let p', deleted = remove_unused_blocks' p in
  if times () then Format.eprintf " dead block: %a@." Timer.print timer;
  if stats () then Format.eprintf "Stats - dead block: deleted %d@." deleted;
  if debug_stats () then Code.check_updates ~name:"dead block" p p' ~updates:deleted;
  p'

(****)

let rec add_arg_dep defs params args =
match params, args with
| x :: params, y :: args ->
Expand All @@ -194,33 +219,32 @@ let rec add_arg_dep defs params args =
| _ -> assert false

let add_cont_dep blocks defs (pc, args) =
match try Some (Addr.Map.find pc blocks) with Not_found -> None with
| Some block -> add_arg_dep defs block.params args
| None -> () (* Dead continuation *)
let block = Addr.Map.find pc blocks in
add_arg_dep defs block.params args

(* A block body counts as empty when it contains no instruction at all, or
   nothing but a single [Event]. *)
let empty_body = function
  | [] | [ Event _ ] -> true
  | _ :: _ -> false

let remove_empty_blocks ~live_vars (p : Code.program) : Code.program =
let previous_p = p in
let t = Timer.make () in
let count = ref 0 in
let remove_empty_blocks st (p : Code.program) : Code.program =
let shortcuts = Hashtbl.create 16 in
let rec resolve_rec visited ((pc, args) as cont) =
if Addr.Set.mem pc visited
then cont
else
match Hashtbl.find_opt shortcuts pc with
| Some (params, cont) ->
incr count;
let pc', args' = resolve_rec (Addr.Set.add pc visited) cont in
let s = Subst.from_map (Subst.build_mapping params args) in
pc', List.map ~f:s args'
| None -> cont
in
let resolve cont = resolve_rec Addr.Set.empty cont in
let resolve cont =
let cont' = resolve_rec Addr.Set.empty cont in
if not (Code.cont_equal cont cont') then st.block_shortcut <- st.block_shortcut + 1;
cont'
in
Addr.Map.iter
(fun pc block ->
match block with
Expand All @@ -235,7 +259,7 @@ let remove_empty_blocks ~live_vars (p : Code.program) : Code.program =
used as argument to the continuation *)
if
List.for_all
~f:(fun x -> live_vars.(Var.idx x) = 1 && Var.Set.mem x args)
~f:(fun x -> st.live.(Var.idx x) = 1 && Var.Set.mem x args)
params
then Hashtbl.add shortcuts pc (params, cont)
| _ -> ())
Expand All @@ -248,20 +272,20 @@ let remove_empty_blocks ~live_vars (p : Code.program) : Code.program =
(let branch = block.branch in
match branch with
| Branch cont -> Branch (resolve cont)
| Cond (x, cont1, cont2) -> Cond (x, resolve cont1, resolve cont2)
| Cond (x, cont1, cont2) ->
let cont1' = resolve cont1 in
let cont2' = resolve cont2 in
if Code.cont_equal cont1' cont2'
then Branch cont1'
else Cond (x, cont1', cont2')
| Switch (x, a1) -> Switch (x, Array.map ~f:resolve a1)
| Pushtrap (cont1, x, cont2) -> Pushtrap (resolve cont1, x, resolve cont2)
| Poptrap cont -> Poptrap (resolve cont)
| Return _ | Raise _ | Stop -> branch)
})
p.blocks
in
let p = { p with blocks } in
if times () then Format.eprintf " dead code elim. empty blocks: %a@." Timer.print t;
if stats () then Format.eprintf "Stats - dead code empty blocks: %d@." !count;
if debug_stats ()
then Code.check_updates ~name:"emptyblock" previous_p p ~updates:!count;
p
{ p with blocks }

let f ({ blocks; _ } as p : Code.program) =
let previous_p = p in
Expand Down Expand Up @@ -299,52 +323,60 @@ let f ({ blocks; _ } as p : Code.program) =
; deleted_instrs = 0
; deleted_blocks = 0
; deleted_params = 0
; block_shortcut = 0
}
in
mark_reachable st p.start;
if debug () then Print.program Format.err_formatter (fun pc xi -> annot st pc xi) p;
let all_blocks = blocks in
let blocks =
Addr.Map.filter_map
(fun pc block ->
if not (BitSet.mem st.reachable_blocks pc)
then (
st.deleted_blocks <- st.deleted_blocks + 1;
None)
else
Some
{ params = List.filter block.params ~f:(fun x -> st.live.(Var.idx x) > 0)
; body =
List.fold_left block.body ~init:[] ~f:(fun acc i ->
match i, acc with
| Event _, Event _ :: prev ->
(* Avoid consecutive events (keep just the last one) *)
i :: prev
| _ ->
if live_instr st i
then filter_closure all_blocks st i :: acc
else (
st.deleted_instrs <- st.deleted_instrs + 1;
acc))
|> List.rev
; branch = filter_live_last all_blocks st block.branch
})
blocks
let p =
let all_blocks = blocks in
let blocks =
Addr.Map.filter_map
(fun pc block ->
if not (BitSet.mem st.reachable_blocks pc)
then (
st.deleted_blocks <- st.deleted_blocks + 1;
None)
else
Some
{ params = List.filter block.params ~f:(fun x -> st.live.(Var.idx x) > 0)
; body =
List.fold_left block.body ~init:[] ~f:(fun acc i ->
match i, acc with
| Event _, Event _ :: prev ->
(* Avoid consecutive events (keep just the last one) *)
i :: prev
| _ ->
if live_instr st i
then filter_closure all_blocks st i :: acc
else (
st.deleted_instrs <- st.deleted_instrs + 1;
acc))
|> List.rev
; branch = filter_live_last all_blocks st block.branch
})
blocks
in
{ p with blocks }
in
let p = { p with blocks } in
let p = remove_empty_blocks st p in
if times () then Format.eprintf " dead code elim.: %a@." Timer.print t;
if stats ()
then
Format.eprintf
"Stats - dead code: deleted %d instructions, %d blocks, %d parameters@."
"Stats - dead code: deleted %d instructions, %d blocks, %d parameters, %d \
branches@."
st.deleted_instrs
st.deleted_blocks
st.deleted_params;
st.deleted_params
st.block_shortcut;
if debug_stats ()
then
Code.check_updates
~name:"deadcode"
previous_p
p
~updates:(st.deleted_instrs + st.deleted_blocks + st.deleted_params);
~updates:
(st.deleted_instrs + st.deleted_blocks + st.deleted_params + st.block_shortcut);
let p = remove_unused_blocks p in
p, st.live
2 changes: 1 addition & 1 deletion compiler/lib/deadcode.mli
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ type variable_uses =

val f : Code.program -> Code.program * variable_uses

val remove_empty_blocks : live_vars:variable_uses -> Code.program -> Code.program
val remove_unused_blocks : Code.program -> Code.program
44 changes: 17 additions & 27 deletions compiler/lib/driver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ let deadcode' p =
Deadcode.f p

let deadcode p =
let r, live_vars = deadcode' p in
Deadcode.remove_empty_blocks ~live_vars r
let r, _ = deadcode' p in
r

let inline p =
if Config.Flag.inline () && Config.Flag.deadcode ()
Expand Down Expand Up @@ -102,8 +102,6 @@ let effects ~deadcode_sentinal p =
| `Cps | `Double_translation ->
if debug () then Format.eprintf "Effects...@.";
let p, live_vars = Deadcode.f p in
let p = Deadcode.remove_empty_blocks ~live_vars p in
let p, live_vars = Deadcode.f p in
let info = Global_flow.f ~fast:false p in
let p, live_vars =
if Config.Flag.globaldeadcode ()
Expand Down Expand Up @@ -144,51 +142,43 @@ let print p =
if debug () then Code.Print.program Format.err_formatter (fun _ _ -> "") p;
p

let stats = Debug.find "stats"

let rec loop max name round i (p : 'a) : 'a =
if times () then Format.eprintf "%s#%d...@." name i;
let p' = round p in
let debug = times () || stats () in
if debug then Format.eprintf "%s#%d...@." name i;
let p' = round ~first:(i = 1) p in
if i >= max
then (
if times () then Format.eprintf "%s#%d: couldn't reach fix point.@." name i;
if debug then Format.eprintf "%s#%d: couldn't reach fix point.@." name i;
p')
else if Code.equal p' p
then (
if times () then Format.eprintf "%s#%d: fix-point reached.@." name i;
if debug then Format.eprintf "%s#%d: fix-point reached.@." name i;
p')
else loop max name round (i + 1) p'

(* o1 *)

let o1 : 'a -> 'a =
let round ~first : 'a -> 'a =
print
+> tailcall
+> flow
+> specialize
+> eval
+> inline (* inlining may reveal new tailcall opt *)
+> deadcode
+> tailcall
+> phi
+> (if first then Fun.id else phi)
+> flow
+> specialize
+> eval
+> inline
+> deadcode
+> print
+> flow
+> specialize
+> eval
+> inline
+> deadcode
+> phi

(* o1 *)

(* Optimization profile o1: iterate [round] through [loop] with a maximum of
   2 iterations (starting at iteration 1), then chain [phi], [flow],
   [specialize], [eval] and [print] after it.
   NOTE(review): assumes [+>] is left-to-right function composition, as
   suggested by its other uses in this file -- its definition is not visible
   here, confirm. *)
let o1 = loop 2 "round" round 1 +> phi +> flow +> specialize +> eval +> print

(* o2 *)

let o2 = loop 10 "o1" o1 1 +> print
let o2 = loop 10 "round" round 1 +> print

(* o3 *)

let o3 = loop 10 "o1" o1 1 +> print
let o3 = loop 30 "round" round 1 +> print

let generate
~exported_runtime
Expand Down
Loading
Loading