@@ -400,7 +400,10 @@ let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} =
400
400
( [inline_function_statement propto adt fim body]
401
401
@ map_no_loc s_upper )
402
402
; meta= Location_span. empty } ) } )
403
- | Profile (_ , l ) | Block l ->
403
+ | Profile (name , l ) ->
404
+ Profile
405
+ (name, List. map l ~f: (inline_function_statement propto adt fim))
406
+ | Block l ->
404
407
Block (List. map l ~f: (inline_function_statement propto adt fim))
405
408
| SList l ->
406
409
SList (List. map l ~f: (inline_function_statement propto adt fim))
@@ -535,7 +538,7 @@ let unroll_loop_one_step_statement _ =
535
538
( Expr.Fixed.
536
539
{ lower with
537
540
pattern=
538
- FunApp (StanLib (" Geq__" , FnPlain , SoA ), [upper; lower]) }
541
+ FunApp (StanLib (" Geq__" , FnPlain , AoS ), [upper; lower]) }
539
542
, { pattern=
540
543
(let body_unrolled =
541
544
subst_args_stmt [loopvar] [lower]
@@ -550,7 +553,7 @@ let unroll_loop_one_step_statement _ =
550
553
{ lower with
551
554
pattern=
552
555
FunApp
553
- ( StanLib (" Plus__" , FnPlain , SoA )
556
+ ( StanLib (" Plus__" , FnPlain , AoS )
554
557
, [lower; Expr.Helpers. loop_bottom] ) } }
555
558
; meta= Location_span. empty } in
556
559
match body_unrolled.pattern with
@@ -722,7 +725,10 @@ let dead_code_elimination (mir : Program.Typed.t) =
722
725
&& is_skip_break_continue body.pattern
723
726
then Skip
724
727
else For {loopvar; lower; upper; body}
725
- | Profile (_ , l ) | Block l ->
728
+ | Profile (name , l ) ->
729
+ let l' = List. filter ~f: (fun x -> x.pattern <> Skip ) l in
730
+ if List. length l' = 0 then Skip else Profile (name, l')
731
+ | Block l ->
726
732
let l' = List. filter ~f: (fun x -> x.pattern <> Skip ) l in
727
733
if List. length l' = 0 then Skip else Block l'
728
734
| SList l ->
@@ -1118,6 +1124,64 @@ let optimize_ad_levels (mir : Program.Typed.t) =
1118
1124
stmt in
1119
1125
transform_program_blockwise mir transform
1120
1126
1127
+ (* *
1128
+ * Deduces whether types can be Structures of Arrays (SoA/fast) or
1129
+ * Arrays of Structs (AoS/slow). See the docs in
1130
+ * Mem_pattern.query_demotable_stmt/exprs* functions for
1131
+ * details on the rules surrounding when demotion from
1132
+ * SoA -> AoS needs to happen.
1133
+ *
1134
+ * This first does a simple iter over
1135
+ * the log_prob portion of the MIR, finding the names of all matrices
1136
+ * (and arrays of matrices) where either the Stan math function
1137
+ * does not support SoA or the object is single-cell accessed within a
1138
+ * For or While loop. These are the initial variables
1139
+ * given to the monotone framework. Then log_prob has all matrix like objects
1140
+ * and the functions that use them set to SoA. After that the
1141
+ * Monotone framework is used to deduce assignment paths of AoS <-> SoA
1142
+ * and vice versa which need to be demoted to AoS as well as updating
1143
+ * functions and objects after these assignment passes that then
1144
+ * also need to be AoS.
1145
+ *
1146
+ * @param mir: The program's whole MIR.
1147
+ *)
1148
let optimize_soa (mir : Program.Typed.t) =
  (* Transfer function handed to [optimize_minimal_variables]: look up the
     statement stored under flowgraph label [l] and ask [Mem_pattern] which
     names it demotes, given the names already in [aos_variables].
     [Map.find_exn] raises if [l] has no entry in [flowgraph_to_mir]. *)
  let gen_aos_variables
      (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
      (l : int) (aos_variables : string Set.Poly.t) =
    let node = Map.find_exn flowgraph_to_mir l in
    Mem_pattern.query_demotable_stmt aos_variables node.pattern in
  (* Seed set: fold [Mem_pattern.query_initial_demotable_stmt false] over the
     top-level statements of [log_prob], starting from the empty set. *)
  let initial_variables =
    List.fold mir.log_prob ~init:Set.Poly.empty
      ~f:(Mem_pattern.query_initial_demotable_stmt false) in
  (* Rewrites applied once the variable set is known: expressions are
     rewritten recursively, statements through a single [Mem_pattern] call. *)
  let update_expr demoted expr =
    Mir_utils.map_rec_expr (Mem_pattern.modify_expr_pattern demoted) expr in
  let update_stmt stmt_pattern demoted =
    Mem_pattern.modify_stmt_pattern stmt_pattern demoted in
  let transform stmt =
    optimize_minimal_variables ~gen_variables:gen_aos_variables ~update_expr
      ~update_stmt ~initial_variables
      ~extra_variables:(fun _ -> initial_variables)
      stmt in
  (* [transform] takes a single statement, so the [log_prob] list is packed
     into one [SList] and unpacked again afterwards. *)
  let transformed_log_prob =
    match transform {pattern= SList mir.log_prob; meta= Location_span.empty} with
    | { pattern=
          SList
            (stmts : (Expr.Typed.Meta.t, Stmt.Located.Meta.t) Stmt.Fixed.t list)
      ; _ } ->
        stmts
    | _ ->
        (* Reached only if the transformation did not hand back an [SList]. *)
        failwith " Something went wrong with program transformation packing!"
  in
  {mir with log_prob= transformed_log_prob}
1184
+
1121
1185
(* Apparently you need to completely copy/paste type definitions between
1122
1186
ml and mli files?*)
1123
1187
type optimization_settings =
@@ -1134,7 +1198,8 @@ type optimization_settings =
1134
1198
; partial_evaluation : bool
1135
1199
; lazy_code_motion : bool
1136
1200
; optimize_ad_levels : bool
1137
- ; preserve_stability : bool }
1201
+ ; preserve_stability : bool
1202
+ ; optimize_soa : bool }
1138
1203
1139
1204
let settings_const b =
1140
1205
{ function_inlining= b
@@ -1150,7 +1215,8 @@ let settings_const b =
1150
1215
; partial_evaluation= b
1151
1216
; lazy_code_motion= b
1152
1217
; optimize_ad_levels= b
1153
- ; preserve_stability= not b }
1218
+ ; preserve_stability= not b
1219
+ ; optimize_soa= b }
1154
1220
1155
1221
let all_optimizations : optimization_settings = settings_const true
1156
1222
let no_optimizations : optimization_settings = settings_const false
@@ -1159,7 +1225,7 @@ type optimization_level = O0 | O1 | Oexperimental
1159
1225
1160
1226
let level_optimizations (lvl : optimization_level ) : optimization_settings =
1161
1227
match lvl with
1162
- | O0 -> { no_optimizations with allow_uninitialized_decls = false }
1228
+ | O0 -> no_optimizations
1163
1229
| O1 ->
1164
1230
{ function_inlining= false
1165
1231
; static_loop_unrolling= false
@@ -1174,7 +1240,8 @@ let level_optimizations (lvl : optimization_level) : optimization_settings =
1174
1240
; lazy_code_motion= false
1175
1241
; allow_uninitialized_decls= false
1176
1242
; optimize_ad_levels= true
1177
- ; preserve_stability= false }
1243
+ ; preserve_stability= false
1244
+ ; optimize_soa= true }
1178
1245
| Oexperimental -> all_optimizations
1179
1246
1180
1247
let optimization_suite ?(settings = all_optimizations) mir =
@@ -1220,7 +1287,8 @@ let optimization_suite ?(settings = all_optimizations) mir =
1220
1287
; (optimize_ad_levels, settings.optimize_ad_levels)
1221
1288
(* Book: Machine idioms and instruction combining *)
1222
1289
(* Matthijs: Everything < block_fixing *)
1223
- ; (block_fixing, settings.block_fixing) ] in
1290
+ ; (block_fixing, settings.block_fixing)
1291
+ ; (optimize_soa, settings.optimize_soa) ] in
1224
1292
let optimizations =
1225
1293
List. filter_map maybe_optimizations ~f: (fun (fn , flag ) ->
1226
1294
if flag then Some fn else None ) in
0 commit comments