# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

+import logging
import math
import unittest
+from typing import cast

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
@@ -110,7 +112,121 @@ def forward(self, x):
class TestMemTransform(unittest.TestCase):
-    def test_optimize_cat(self):
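+    # The helpers below validate the memory planner's work: a "nop" cat/slice/select
+    # performs no copy at runtime, so planning must have aliased the input and output
+    # tensors into one buffer. Each _verify_*_nop_memory_alloc helper recomputes the
+    # expected byte offsets from the nodes' TensorSpec ("spec") metadata and asserts
+    # that the planned mem_id/mem_offset pairs match.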
+    def _verify_cat_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        dim_offset = 0
+        for arg in cast(list[torch.fx.Node], node.args[0]):
+            arg_spec = arg.meta.get("spec", None)
+            self.assertEqual(arg_spec.mem_id, spec.mem_id)
+            self.assertEqual(
+                arg_spec.mem_offset,
+                spec.mem_offset + dim_offset * inner_dim_elements,
+                f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
+            )
+            dim_offset += arg_spec.shape[dim]
+
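+    # Worked example with assumed shapes: for a nop cat of float32 inputs (3, 6) and
+    # (2, 6) on dim 0, the output is (5, 6); inner_dim_elements = 6 * 4 = 24 bytes,
+    # so the second input must be planned at the output's mem_offset + 3 * 24 = 72.
+    # The slice/select checks below assert the inverse relation: the output aliases
+    # into its input at start (or index) * inner_dim_elements bytes.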
+    def _verify_slice_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        start: int = (
+            cast(int, node.args[2])
+            if (len(node.args) > 2 and node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + start * inner_dim_elements,
+            f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}",
+        )
+
+    def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        # The output of select drops `dim`, so spec.shape[dim:] here equals the
+        # input's shape[dim + 1:]: the byte stride of one step along `dim`.
+        inner_dim_elements = math.prod(spec.shape[dim:]) * spec.dtype.itemsize
+        index: int = (
+            cast(int, node.args[2])
+            if (len(node.args) > 2 and node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + index * inner_dim_elements,
+            f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} for select on {dim=} {index=}, "
+            f"but output has {spec.mem_offset=} "
+            f"{spec=} {arg_spec=}",
+        )
+
+    def verify_nop_memory_alloc(self, graph_module):
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._cat_nop.out
+        ):
+            self._verify_cat_nop_memory_alloc(node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
+        ):
+            self._verify_slice_nop_memory_alloc(node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._select_copy_nop.int_out
+        ):
+            self._verify_select_nop_memory_alloc(node)
+
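+    # Test pattern used throughout: build a small module, lower it with
+    # compiler.export_to_executorch_gen_etrecord at a given opt_level, count the
+    # surviving cat/slice/select ops versus their nop variants, and then run
+    # verify_nop_memory_alloc over the planned graph.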
+    def test_optimize_cat_on_placeholders(self):
+        class Cat(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat((x, y))
+
+        x = torch.ones(3, 6)
+        y = torch.ones(2, 6)
+        # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
+        # pass to run:
+        graph_module = (
+            compiler.export_to_executorch_gen_etrecord(
+                Cat(), (x, y), opt_level=2, mem_algo=1
+            )
+            .exported_program()
+            .graph_module
+        )
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        graph_module.graph.eliminate_dead_code()
+        # Assert that cat op is optimized away
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        # Assert that cat op is replaced by its nop version post optimization
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_outermost(self):
        class OptimizeCatFeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -135,7 +251,9 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        # Assert that cat op is replaced by its nop version post optimization
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_optimize_cat_non_outermost(self):
        class OptimizeCatFeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +278,9 @@ def forward(self, x, y):
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
        # Assert that cat op is replaced by its nop version post optimization
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost(self):
        class OptimizeCatInfeasible1(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +304,9 @@ def forward(self, x, y):
        # Assert that cat op is not optimized away, since the concat is not
        # along the outermost dim
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost1(self):
        class OptimizeCatInfeasible2(torch.nn.Module):
            def forward(self, x, y):
                x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +331,7 @@ def forward(self, x, y):
        # offsets are not multiple of 8 bytes, and the cat is not the output
        # of the graph.
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_cat_with_slice(self):
        class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +360,7 @@ def forward(self, x):
        graph_module.graph.eliminate_dead_code()
        # Assert that cat op is optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_cat_with_slice_infeasible(self):
        class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +386,7 @@ def forward(self, x, y):
        graph_module.graph.eliminate_dead_code()
        # Assert that cat op is not optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_slice_Tensor(self):
        class SliceTensor(torch.nn.Module):
@@ -323,6 +448,7 @@ def forward(self, x, y, z):
        self.assertEqual(
            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
        )
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_select_Tensor(self):
        class SelectTensor(torch.nn.Module):
@@ -387,6 +513,7 @@ def forward(self, x, y, z):
        self.assertEqual(
            count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
        )
+        self.verify_nop_memory_alloc(graph_module)

    # TODO: Test fails due to memory planning
    @unittest.expectedFailure
@@ -416,6 +543,32 @@ def forward(self, x, y):
        graph_module.graph.eliminate_dead_code()
        # Assert that cat op is not optimized away
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
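+    # The next test feeds a mutable registered buffer through cat and slices the
+    # result back into it via copy_; even with that mutation, the cat should still
+    # lower to _cat_nop at opt_level 3.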
+    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+        class CatWithPadding(torch.nn.Module):
+            def __init__(self, padding_shape):
+                super().__init__()
+                zeros = torch.zeros(padding_shape)
+                self.register_buffer("padding", zeros)
+
+            def forward(self, x, y):
+                x = x.view(3, 5)
+                cat = torch.ops.aten.cat((x, self.padding.clone()))
+                slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
+                self.padding.copy_(slice_copy)
+                return cat.view(-1) + y
+
+        x = torch.ones(15)
+        y = torch.ones(1)
+        et_prog_manager = compiler.export_to_executorch_gen_etrecord(
+            CatWithPadding((1, 5)), (x, y), opt_level=3
+        )
+        graph_module = et_prog_manager.exported_program().graph_module
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_cat_with_view(self):
        class CatViewFeasible(torch.nn.Module):
@@ -442,6 +595,7 @@ def forward(self, x, y):
        # Assert that cat op is optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

    def test_no_optimize_cat_with_repeated_args(self):
        class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +619,7 @@ def forward(self, x):
        # Assert that cat op is not optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

    def test_no_optimize_cat_with_placeholder(self):
        class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +647,7 @@ def forward(self, x, y):
        # Assert that cat op is not optimized away
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

    def test_no_optimize_cat(self) -> None:
        class Model(torch.nn.Module):
@@ -522,6 +678,7 @@ def forward(self, x) -> torch.Tensor:
            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
        )
        self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

    def test_optimize_slice_copy(self) -> None:
        class Model(torch.nn.Module):
@@ -553,6 +710,7 @@ def forward(self, x) -> torch.Tensor:
            count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
        )
        self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

    def test_cat_then_cat(self) -> None:
        class Model(torch.nn.Module):
@@ -579,6 +737,7 @@ def forward(self, x) -> torch.Tensor:
        graph_module.print_readable()
        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

    def test_view_for_unallocated_output(self):
        class Model(torch.nn.Module):
@@ -602,3 +761,4 @@ def forward(self, x, y):
            .graph_module
        )
        self.assertEqual(count_node(graph_module, memory.view), 1)
+        self.verify_nop_memory_alloc(graph_module)