19
19
//! When unrolling ACIR code, we remove reference count instructions because they are
20
20
//! only used by Brillig bytecode.
21
21
use acvm:: { acir:: AcirField , FieldElement } ;
22
+ use im:: HashSet ;
22
23
23
24
use crate :: {
25
+ brillig:: brillig_gen:: convert_ssa_function,
24
26
errors:: RuntimeError ,
25
27
ssa:: {
26
28
ir:: {
@@ -37,38 +39,60 @@ use crate::{
37
39
ssa_gen:: Ssa ,
38
40
} ,
39
41
} ;
40
- use fxhash:: { FxHashMap as HashMap , FxHashSet as HashSet } ;
42
+ use fxhash:: FxHashMap as HashMap ;
41
43
42
44
impl Ssa {
43
45
/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
44
46
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
45
- #[ tracing:: instrument( level = "trace" , skip( ssa) ) ]
46
- pub ( crate ) fn unroll_loops_iteratively ( mut ssa : Ssa ) -> Result < Ssa , RuntimeError > {
47
- for ( _, function) in ssa. functions . iter_mut ( ) {
47
+ ///
48
+ /// The `max_bytecode_incr_pct`, when given, is used to limit the growth of the Brillig bytecode size
49
+ /// after unrolling small loops to some percentage of the original loop. For example a value of 150 would
50
+ /// mean the new loop can be 150% (ie. 2.5 times) larger than the original loop. It will still contain
51
+ /// fewer SSA instructions, but that can still result in more Brillig opcodes.
52
+ #[ tracing:: instrument( level = "trace" , skip( self ) ) ]
53
+ pub ( crate ) fn unroll_loops_iteratively (
54
+ mut self : Ssa ,
55
+ max_bytecode_increase_percent : Option < i32 > ,
56
+ ) -> Result < Ssa , RuntimeError > {
57
+ for ( _, function) in self . functions . iter_mut ( ) {
58
+ // Take a snapshot of the function to compare byte size increase,
59
+ // but only if the setting indicates we have to, otherwise skip it.
60
+ let orig_func_and_max_incr_pct = max_bytecode_increase_percent
61
+ . filter ( |_| function. runtime ( ) . is_brillig ( ) )
62
+ . map ( |max_incr_pct| ( function. clone ( ) , max_incr_pct) ) ;
63
+
48
64
// Try to unroll loops first:
49
- let mut unroll_errors = function. try_unroll_loops ( ) ;
65
+ let ( mut has_unrolled , mut unroll_errors) = function. try_unroll_loops ( ) ;
50
66
51
67
// Keep unrolling until no more errors are found
52
68
while !unroll_errors. is_empty ( ) {
53
69
let prev_unroll_err_count = unroll_errors. len ( ) ;
54
70
55
71
// Simplify the SSA before retrying
56
-
57
- // Do a mem2reg after the last unroll to aid simplify_cfg
58
- function. mem2reg ( ) ;
59
- function. simplify_function ( ) ;
60
- // Do another mem2reg after simplify_cfg to aid the next unroll
61
- function. mem2reg ( ) ;
72
+ simplify_between_unrolls ( function) ;
62
73
63
74
// Unroll again
64
- unroll_errors = function. try_unroll_loops ( ) ;
75
+ let ( new_unrolled, new_errors) = function. try_unroll_loops ( ) ;
76
+ unroll_errors = new_errors;
77
+ has_unrolled |= new_unrolled;
78
+
65
79
// If we didn't manage to unroll any more loops, exit
66
80
if unroll_errors. len ( ) >= prev_unroll_err_count {
67
81
return Err ( unroll_errors. swap_remove ( 0 ) ) ;
68
82
}
69
83
}
84
+
85
+ if has_unrolled {
86
+ if let Some ( ( orig_function, max_incr_pct) ) = orig_func_and_max_incr_pct {
87
+ let new_size = brillig_bytecode_size ( function) ;
88
+ let orig_size = brillig_bytecode_size ( & orig_function) ;
89
+ if !is_new_size_ok ( orig_size, new_size, max_incr_pct) {
90
+ * function = orig_function;
91
+ }
92
+ }
93
+ }
70
94
}
71
- Ok ( ssa )
95
+ Ok ( self )
72
96
}
73
97
}
74
98
@@ -77,7 +101,7 @@ impl Function {
77
101
// This can also be true for ACIR, but we have no alternative to unrolling in ACIR.
78
102
// Brillig also generally prefers smaller code rather than faster code,
79
103
// so we only attempt to unroll small loops, which we decide on a case-by-case basis.
80
- fn try_unroll_loops ( & mut self ) -> Vec < RuntimeError > {
104
+ fn try_unroll_loops ( & mut self ) -> ( bool , Vec < RuntimeError > ) {
81
105
Loops :: find_all ( self ) . unroll_each ( self )
82
106
}
83
107
}
@@ -170,8 +194,10 @@ impl Loops {
170
194
171
195
/// Unroll all loops within a given function.
172
196
/// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
173
- fn unroll_each ( mut self , function : & mut Function ) -> Vec < RuntimeError > {
197
+ /// Returns whether any blocks have been modified
198
+ fn unroll_each ( mut self , function : & mut Function ) -> ( bool , Vec < RuntimeError > ) {
174
199
let mut unroll_errors = vec ! [ ] ;
200
+ let mut has_unrolled = false ;
175
201
while let Some ( next_loop) = self . yet_to_unroll . pop ( ) {
176
202
if function. runtime ( ) . is_brillig ( ) && !next_loop. is_small_loop ( function, & self . cfg ) {
177
203
continue ;
@@ -181,21 +207,25 @@ impl Loops {
181
207
if next_loop. blocks . iter ( ) . any ( |block| self . modified_blocks . contains ( block) ) {
182
208
let mut new_loops = Self :: find_all ( function) ;
183
209
new_loops. failed_to_unroll = self . failed_to_unroll ;
184
- return unroll_errors. into_iter ( ) . chain ( new_loops. unroll_each ( function) ) . collect ( ) ;
210
+ let ( new_unrolled, new_errors) = new_loops. unroll_each ( function) ;
211
+ return ( has_unrolled || new_unrolled, [ unroll_errors, new_errors] . concat ( ) ) ;
185
212
}
186
213
187
214
// Don't try to unroll the loop again if it is known to fail
188
215
if !self . failed_to_unroll . contains ( & next_loop. header ) {
189
216
match next_loop. unroll ( function, & self . cfg ) {
190
- Ok ( _) => self . modified_blocks . extend ( next_loop. blocks ) ,
217
+ Ok ( _) => {
218
+ has_unrolled = true ;
219
+ self . modified_blocks . extend ( next_loop. blocks ) ;
220
+ }
191
221
Err ( call_stack) => {
192
222
self . failed_to_unroll . insert ( next_loop. header ) ;
193
223
unroll_errors. push ( RuntimeError :: UnknownLoopBound { call_stack } ) ;
194
224
}
195
225
}
196
226
}
197
227
}
198
- unroll_errors
228
+ ( has_unrolled , unroll_errors)
199
229
}
200
230
}
201
231
@@ -947,21 +977,59 @@ impl<'f> LoopIteration<'f> {
947
977
}
948
978
}
949
979
980
+ /// Unrolling leaves some duplicate instructions which can potentially be removed.
981
+ fn simplify_between_unrolls ( function : & mut Function ) {
982
+ // Do a mem2reg after the last unroll to aid simplify_cfg
983
+ function. mem2reg ( ) ;
984
+ function. simplify_function ( ) ;
985
+ // Do another mem2reg after simplify_cfg to aid the next unroll
986
+ function. mem2reg ( ) ;
987
+ }
988
+
989
+ /// Convert the function to Brillig bytecode and return the resulting size.
990
+ fn brillig_bytecode_size ( function : & Function ) -> usize {
991
+ // We need to do some SSA passes in order for the conversion to be able to go ahead,
992
+ // otherwise we can hit `unreachable!()` instructions in `convert_ssa_instruction`.
993
+ // Creating a clone so as not to modify the originals.
994
+ let mut temp = function. clone ( ) ;
995
+
996
+ // Might as well give it the best chance.
997
+ simplify_between_unrolls ( & mut temp) ;
998
+
999
+ // This is to try to prevent hitting ICE.
1000
+ temp. dead_instruction_elimination ( false ) ;
1001
+
1002
+ convert_ssa_function ( & temp, false ) . byte_code . len ( )
1003
+ }
1004
+
1005
+ /// Decide if the new bytecode size is acceptable, compared to the original.
1006
+ ///
1007
+ /// The maximum increase can be expressed as a negative value if we demand a decrease.
1008
+ /// (Values -100 and under mean the new size should be 0).
1009
+ fn is_new_size_ok ( orig_size : usize , new_size : usize , max_incr_pct : i32 ) -> bool {
1010
+ let max_size_pct = 100i32 . saturating_add ( max_incr_pct) . max ( 0 ) as usize ;
1011
+ let max_size = orig_size. saturating_mul ( max_size_pct) ;
1012
+ new_size. saturating_mul ( 100 ) <= max_size
1013
+ }
1014
+
950
1015
#[ cfg( test) ]
951
1016
mod tests {
952
1017
use acvm:: FieldElement ;
1018
+ use test_case:: test_case;
953
1019
954
1020
use crate :: errors:: RuntimeError ;
955
1021
use crate :: ssa:: { ir:: value:: ValueId , opt:: assert_normalized_ssa_equals, Ssa } ;
956
1022
957
- use super :: { BoilerplateStats , Loops } ;
1023
+ use super :: { is_new_size_ok , BoilerplateStats , Loops } ;
958
1024
959
- /// Tries to unroll all loops in each SSA function.
1025
+ /// Tries to unroll all loops in each SSA function once, calling the `Function` directly,
1026
+ /// bypassing the iterative loop done by the SSA which does further optimisations.
1027
+ ///
960
1028
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
961
1029
fn try_unroll_loops ( mut ssa : Ssa ) -> ( Ssa , Vec < RuntimeError > ) {
962
1030
let mut errors = vec ! [ ] ;
963
1031
for function in ssa. functions . values_mut ( ) {
964
- errors. extend ( function. try_unroll_loops ( ) ) ;
1032
+ errors. extend ( function. try_unroll_loops ( ) . 1 ) ;
965
1033
}
966
1034
( ssa, errors)
967
1035
}
@@ -1221,9 +1289,26 @@ mod tests {
1221
1289
1222
1290
let ( ssa, errors) = try_unroll_loops ( ssa) ;
1223
1291
assert_eq ! ( errors. len( ) , 0 , "Unroll should have no errors" ) ;
1292
+ // Check that it's still the original
1224
1293
assert_normalized_ssa_equals ( ssa, parse_ssa ( ) . to_string ( ) . as_str ( ) ) ;
1225
1294
}
1226
1295
1296
+ #[ test]
1297
+ fn test_brillig_unroll_iteratively_respects_max_increase ( ) {
1298
+ let ssa = brillig_unroll_test_case ( ) ;
1299
+ let ssa = ssa. unroll_loops_iteratively ( Some ( -90 ) ) . unwrap ( ) ;
1300
+ // Check that it's still the original
1301
+ assert_normalized_ssa_equals ( ssa, brillig_unroll_test_case ( ) . to_string ( ) . as_str ( ) ) ;
1302
+ }
1303
+
1304
+ #[ test]
1305
+ fn test_brillig_unroll_iteratively_with_large_max_increase ( ) {
1306
+ let ssa = brillig_unroll_test_case ( ) ;
1307
+ let ssa = ssa. unroll_loops_iteratively ( Some ( 50 ) ) . unwrap ( ) ;
1308
+ // Check that it did the unroll
1309
+ assert_eq ! ( ssa. main( ) . reachable_blocks( ) . len( ) , 2 , "The loop should be unrolled" ) ;
1310
+ }
1311
+
1227
1312
/// Test that `break` and `continue` stop unrolling without any panic.
1228
1313
#[ test]
1229
1314
fn test_brillig_unroll_break_and_continue ( ) {
@@ -1377,4 +1462,14 @@ mod tests {
1377
1462
let loop0 = loops. yet_to_unroll . pop ( ) . expect ( "there should be a loop" ) ;
1378
1463
loop0. boilerplate_stats ( function, & loops. cfg ) . expect ( "there should be stats" )
1379
1464
}
1465
+
1466
+ #[ test_case( 1000 , 700 , 50 , true ; "size decreased" ) ]
1467
+ #[ test_case( 1000 , 1500 , 50 , true ; "size increased just by the max" ) ]
1468
+ #[ test_case( 1000 , 1501 , 50 , false ; "size increased over the max" ) ]
1469
+ #[ test_case( 1000 , 700 , -50 , false ; "size decreased but not enough" ) ]
1470
+ #[ test_case( 1000 , 250 , -50 , true ; "size decreased over expectations" ) ]
1471
+ #[ test_case( 1000 , 250 , -1250 , false ; "demanding more than minus 100 is handled" ) ]
1472
+ fn test_is_new_size_ok ( old : usize , new : usize , max : i32 , ok : bool ) {
1473
+ assert_eq ! ( is_new_size_ok( old, new, max) , ok) ;
1474
+ }
1380
1475
}
0 commit comments