Skip to content

Commit c58fcc1

Browse files
authored
Merge pull request #955 from SteveBronder/feature/soa-optim
Add optims to detect when SoA matrices can be used
2 parents 92092ea + 8d068b4 commit c58fcc1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+32423
-22581
lines changed

src/analysis_and_optimization/Mem_pattern.ml

Lines changed: 636 additions & 0 deletions
Large diffs are not rendered by default.

src/analysis_and_optimization/Mir_utils.ml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,15 @@ let fwd_traverse_statement stmt ~init ~f =
200200
| For vars ->
201201
let s', c = f init vars.body in
202202
(s', For {vars with body= c})
203-
| Profile (_, stmts) | Block stmts ->
203+
| Profile (name, stmts) ->
204+
let s', ls =
205+
List.fold_left stmts
206+
~f:(fun (s, l) stmt ->
207+
let s', c = f s stmt in
208+
(s', List.cons c l) )
209+
~init:(init, []) in
210+
(s', Profile (name, List.rev ls))
211+
| Block stmts ->
204212
let s', ls =
205213
List.fold_left stmts
206214
~f:(fun (s, l) stmt ->

src/analysis_and_optimization/Monotone_framework.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ let minimal_variables_fwd_transfer
770770
let gen = gen_variable flowgraph_to_mir l p in
771771
let kill =
772772
match mir_node with
773+
(* This probably isn't necessary because Stan doesn't allow shadowing, right? *)
773774
| Decl {decl_id; decl_adtype= DataOnly; _} -> Set.Poly.singleton decl_id
774775
| _ -> Set.Poly.empty in
775776
transfer_gen_kill p gen kill end : TRANSFER_FUNCTION

src/analysis_and_optimization/Optimize.ml

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,10 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} =
400400
( [inline_function_statement propto adt fim body]
401401
@ map_no_loc s_upper )
402402
; meta= Location_span.empty } ) } )
403-
| Profile (_, l) | Block l ->
403+
| Profile (name, l) ->
404+
Profile
405+
(name, List.map l ~f:(inline_function_statement propto adt fim))
406+
| Block l ->
404407
Block (List.map l ~f:(inline_function_statement propto adt fim))
405408
| SList l ->
406409
SList (List.map l ~f:(inline_function_statement propto adt fim))
@@ -535,7 +538,7 @@ let unroll_loop_one_step_statement _ =
535538
( Expr.Fixed.
536539
{ lower with
537540
pattern=
538-
FunApp (StanLib ("Geq__", FnPlain, SoA), [upper; lower]) }
541+
FunApp (StanLib ("Geq__", FnPlain, AoS), [upper; lower]) }
539542
, { pattern=
540543
(let body_unrolled =
541544
subst_args_stmt [loopvar] [lower]
@@ -550,7 +553,7 @@ let unroll_loop_one_step_statement _ =
550553
{ lower with
551554
pattern=
552555
FunApp
553-
( StanLib ("Plus__", FnPlain, SoA)
556+
( StanLib ("Plus__", FnPlain, AoS)
554557
, [lower; Expr.Helpers.loop_bottom] ) } }
555558
; meta= Location_span.empty } in
556559
match body_unrolled.pattern with
@@ -722,7 +725,10 @@ let dead_code_elimination (mir : Program.Typed.t) =
722725
&& is_skip_break_continue body.pattern
723726
then Skip
724727
else For {loopvar; lower; upper; body}
725-
| Profile (_, l) | Block l ->
728+
| Profile (name, l) ->
729+
let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in
730+
if List.length l' = 0 then Skip else Profile (name, l')
731+
| Block l ->
726732
let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in
727733
if List.length l' = 0 then Skip else Block l'
728734
| SList l ->
@@ -1118,6 +1124,64 @@ let optimize_ad_levels (mir : Program.Typed.t) =
11181124
stmt in
11191125
transform_program_blockwise mir transform
11201126

1127+
(**
1128+
* Deduces whether types can be Structures of Arrays (SoA/fast) or
1129+
* Arrays of Structs (AoS/slow). See the docs in
1130+
* Mem_pattern.query_demote_stmt/exprs* functions for
1131+
* details on the rules surrounding when demotion from
1132+
* SoA -> AoS needs to happen.
1133+
*
1134+
* This first does a simple iter over
1135+
* the log_prob portion of the MIR, finding the names of all matrices
1136+
* (and arrays of matrices) where either the Stan math function
1137+
* does not support SoA or the object is single cell accesed within a
1138+
* For or While loop. These are the initial variables
1139+
* given to the monotone framework. Then log_prob has all matrix like objects
1140+
* and the functions that use them to SoA. After that the
1141+
* Monotone framework is used to deduce assignment paths of AoS <-> SoA
1142+
* and vice versa which need to be demoted to AoS as well as updating
1143+
* functions and objects after these assignment passes that then
1144+
* also need to be AoS.
1145+
*
1146+
* @param mir: The program's whole MIR.
1147+
*)
1148+
let optimize_soa (mir : Program.Typed.t) =
1149+
let gen_aos_variables
1150+
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
1151+
(l : int) (aos_variables : string Set.Poly.t) =
1152+
let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in
1153+
match (mir_node l).pattern with
1154+
| stmt -> Mem_pattern.query_demotable_stmt aos_variables stmt in
1155+
let initial_variables =
1156+
List.fold ~init:Set.Poly.empty
1157+
~f:(Mem_pattern.query_initial_demotable_stmt false)
1158+
mir.log_prob in
1159+
(*
1160+
let print_set s =
1161+
Set.Poly.iter ~f:print_endline s in
1162+
let () = print_set initial_variables in
1163+
*)
1164+
let mod_exprs aos_exits mod_expr =
1165+
Mir_utils.map_rec_expr (Mem_pattern.modify_expr_pattern aos_exits) mod_expr
1166+
in
1167+
let modify_stmt_patt stmt_pattern variable_set =
1168+
Mem_pattern.modify_stmt_pattern stmt_pattern variable_set in
1169+
let transform stmt =
1170+
optimize_minimal_variables ~gen_variables:gen_aos_variables
1171+
~update_expr:mod_exprs ~update_stmt:modify_stmt_patt ~initial_variables
1172+
stmt ~extra_variables:(fun _ -> initial_variables) in
1173+
let transform' s =
1174+
match transform {pattern= SList s; meta= Location_span.empty} with
1175+
| { pattern=
1176+
SList (l : (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t list)
1177+
; _ } ->
1178+
l
1179+
| _ ->
1180+
raise
1181+
(Failure "Something went wrong with program transformation packing!")
1182+
in
1183+
{mir with log_prob= transform' mir.log_prob}
1184+
11211185
(* Apparently you need to completely copy/paste type definitions between
11221186
ml and mli files?*)
11231187
type optimization_settings =
@@ -1134,7 +1198,8 @@ type optimization_settings =
11341198
; partial_evaluation: bool
11351199
; lazy_code_motion: bool
11361200
; optimize_ad_levels: bool
1137-
; preserve_stability: bool }
1201+
; preserve_stability: bool
1202+
; optimize_soa: bool }
11381203

11391204
let settings_const b =
11401205
{ function_inlining= b
@@ -1150,7 +1215,8 @@ let settings_const b =
11501215
; partial_evaluation= b
11511216
; lazy_code_motion= b
11521217
; optimize_ad_levels= b
1153-
; preserve_stability= not b }
1218+
; preserve_stability= not b
1219+
; optimize_soa= b }
11541220

11551221
let all_optimizations : optimization_settings = settings_const true
11561222
let no_optimizations : optimization_settings = settings_const false
@@ -1159,7 +1225,7 @@ type optimization_level = O0 | O1 | Oexperimental
11591225

11601226
let level_optimizations (lvl : optimization_level) : optimization_settings =
11611227
match lvl with
1162-
| O0 -> {no_optimizations with allow_uninitialized_decls= false}
1228+
| O0 -> no_optimizations
11631229
| O1 ->
11641230
{ function_inlining= false
11651231
; static_loop_unrolling= false
@@ -1174,7 +1240,8 @@ let level_optimizations (lvl : optimization_level) : optimization_settings =
11741240
; lazy_code_motion= false
11751241
; allow_uninitialized_decls= false
11761242
; optimize_ad_levels= true
1177-
; preserve_stability= false }
1243+
; preserve_stability= false
1244+
; optimize_soa= true }
11781245
| Oexperimental -> all_optimizations
11791246

11801247
let optimization_suite ?(settings = all_optimizations) mir =
@@ -1220,7 +1287,8 @@ let optimization_suite ?(settings = all_optimizations) mir =
12201287
; (optimize_ad_levels, settings.optimize_ad_levels)
12211288
(* Book: Machine idioms and instruction combining *)
12221289
(* Matthijs: Everything < block_fixing *)
1223-
; (block_fixing, settings.block_fixing) ] in
1290+
; (block_fixing, settings.block_fixing)
1291+
; (optimize_soa, settings.optimize_soa) ] in
12241292
let optimizations =
12251293
List.filter_map maybe_optimizations ~f:(fun (fn, flag) ->
12261294
if flag then Some fn else None ) in

src/analysis_and_optimization/Optimize.mli

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ type optimization_settings =
7878
; partial_evaluation: bool
7979
; lazy_code_motion: bool
8080
; optimize_ad_levels: bool
81-
; preserve_stability: bool }
81+
; preserve_stability: bool
82+
; optimize_soa: bool }
8283

8384
val all_optimizations : optimization_settings
8485
val no_optimizations : optimization_settings

src/frontend/Ast_to_Mir.ml

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -187,14 +187,6 @@ type transform_action = Check | Constrain | Unconstrain | IgnoreTransform
187187
type decl_context =
188188
{transform_action: transform_action; dadlevel: UnsizedType.autodifftype}
189189

190-
let constraint_forl = function
191-
| Transformation.Identity | Offset _ | Multiplier _ | OffsetMultiplier _
192-
|Lower _ | Upper _ | LowerUpper _ ->
193-
Stmt.Helpers.for_scalar
194-
| Ordered | PositiveOrdered | Simplex | UnitVector | CholeskyCorr
195-
|CholeskyCov | Correlation | Covariance ->
196-
Stmt.Helpers.for_eigen
197-
198190
let same_shape decl_id decl_var id var meta =
199191
if UnsizedType.is_scalar_type (Expr.Typed.type_of var) then []
200192
else
@@ -291,16 +283,7 @@ let param_size transform sizedtype =
291283
(fun k -> Expr.Helpers.(binop k Plus (k_choose_2 k)))
292284
sizedtype
293285

294-
let remove_possibly_exn pst action loc =
295-
match pst with
296-
| Type.Sized st -> st
297-
| Unsized _ ->
298-
Common.FatalError.fatal_error_msg
299-
[%message
300-
"Error extracting sizedtype" ~action ~loc:(loc : Location_span.t)]
301-
302286
let rec check_decl var decl_type' decl_id decl_trans smeta adlevel =
303-
let decl_type = remove_possibly_exn decl_type' "check" smeta in
304287
match decl_trans with
305288
| Transformation.LowerUpper (lb, ub) ->
306289
check_decl var decl_type' decl_id (Lower lb) smeta adlevel
@@ -312,7 +295,7 @@ let rec check_decl var decl_type' decl_id decl_trans smeta adlevel =
312295
Stmt.Helpers.internal_nrfunapp
313296
(FnCheck {trans= decl_trans; var_name; var= id})
314297
args smeta in
315-
[(constraint_forl decl_trans) decl_type check_id var smeta]
298+
[check_id var]
316299
| _ -> []
317300

318301
let check_sizedtype name =
@@ -687,7 +670,7 @@ let trans_block ud_dists declc block prog =
687670
check_transform_shape decl_id decl_var smeta.loc transform
688671
| Check ->
689672
check_transform_shape decl_id decl_var smeta.loc transform
690-
@ check_decl decl_var (Sized type_) decl_id transform
673+
@ check_decl decl_var (Type.Sized type_) decl_id transform
691674
smeta.loc declc.dadlevel
692675
| IgnoreTransform -> [] in
693676
(decl :: rhs_assignment) @ constrain_checks

src/middle/Expr.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ module Typed = struct
8888
let type_of Fixed.{meta= Meta.{type_; _}; _} = type_
8989
let loc_of Fixed.{meta= Meta.{loc; _}; _} = loc
9090
let adlevel_of Fixed.{meta= Meta.{adlevel; _}; _} = adlevel
91+
let fun_arg Fixed.{meta= Meta.{type_; adlevel; _}; _} = (adlevel, type_)
9192
end
9293

9394
(** Expressions with associated location, type and label *)

src/middle/Expr.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ module Typed : sig
5151
val type_of : t -> UnsizedType.t
5252
val loc_of : t -> Location_span.t
5353
val adlevel_of : t -> UnsizedType.autodifftype
54+
val fun_arg : t -> UnsizedType.autodifftype * UnsizedType.t
5455
end
5556

5657
module Labelled : sig

src/middle/Index.ml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,18 @@ let bounds = function
3636
@param op a functor to run with inputs of inner exprs
3737
@param ind the Index.t to
3838
*)
39-
let apply ~default ~merge op ind =
39+
let apply ~default ~merge op (ind : 'a t) =
4040
match ind with
4141
| All -> default
4242
| Single ind_expr -> op ind_expr
4343
| Upfrom ind_expr -> op ind_expr
4444
| Between (expr_top, expr_bottom) -> merge (op expr_top) (op expr_bottom)
4545
| MultiIndex exprs -> op exprs
46+
47+
let folder (acc : string Set.Poly.t) op (ind : 'a t) : string Set.Poly.t =
48+
match ind with
49+
| All -> acc
50+
| Single ind_expr | Upfrom ind_expr | MultiIndex ind_expr -> op acc ind_expr
51+
| Between (expr_top, expr_bottom) ->
52+
let top_fold = op acc expr_top in
53+
Set.Poly.union top_fold (op top_fold expr_bottom)

0 commit comments

Comments
 (0)