
Commit 9cc8669

Update
[ghstack-poisoned]
2 parents: ad860de + 1655fc5


56 files changed: +2122, -1439 lines

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 19 additions & 7 deletions
@@ -32,10 +32,16 @@ def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
     Raises a ValueError if the node doesn't have any parameters set.
     """
     if "input_qparams" not in node.meta.keys():
-        raise ValueError(f"No input quantization parameter found in node {node}")
+        raise ValueError(
+            f"No input quantization parameter found in node {node}\n"
+            f"original_aten={node.meta.get('original_aten', 'None')}"
+        )
     input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"])
     if len(input_qparams) == 0:
-        raise ValueError(f"No input quantization parameter found in node {node}")
+        raise ValueError(
+            f"No input quantization parameter found in node {node}\n"
+            f"original_aten={node.meta.get('original_aten', 'None')}"
+        )
     return input_qparams


@@ -45,11 +51,17 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
     Raises a ValueError if the node doesn't have any parameters set.
     """
     if "output_qparams" not in node.meta.keys():
-        raise ValueError(f"No output quantization parameter found in node {node}")
-    input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
-    if len(input_qparams) == 0:
-        raise ValueError(f"No output quantization parameter found in node {node}")
-    return input_qparams
+        raise ValueError(
+            f"No output quantization parameter found in node {node}\n"
+            f"original_aten={node.meta.get('original_aten', 'None')}"
+        )
+    output_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"])
+    if len(output_qparams) == 0:
+        raise ValueError(
+            f"No output quantization parameter found in node {node}\n"
+            f"original_aten={node.meta.get('original_aten', 'None')}"
+        )
+    return output_qparams


 class FoldAndAnnotateQParamsPass(ExportPass):
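Note: besides renaming the misleading input_qparams local to output_qparams, both helpers now append the node's original_aten metadata to the error, so a failure points back at the source ATen op rather than only the (often renamed) FX node. A minimal sketch of the pattern outside the pass (the SimpleNamespace node is a hypothetical stand-in for a torch.fx.Node; only .meta matters here):

from types import SimpleNamespace

# Hypothetical stand-in for a torch.fx.Node carrying only metadata.
node = SimpleNamespace(meta={"original_aten": "aten.convolution.default"})

if "input_qparams" not in node.meta:
    # The error now names the original ATen op alongside the node.
    raise ValueError(
        f"No input quantization parameter found in node {node}\n"
        f"original_aten={node.meta.get('original_aten', 'None')}"
    )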

backends/cadence/aot/memory_planning.py

Lines changed: 4 additions & 2 deletions
@@ -367,7 +367,8 @@ def print_memory_planning_info(

     # Print the memory usage per memory space as a table
     logging.info(
-        tabulate(
+        "\n"
+        + tabulate(
             memory_usage_table,
             headers=[
                 "Memory Space",
@@ -398,7 +399,8 @@ def print_memory_planning_info(

     # Print the total memory usage as a table
     logging.info(
-        tabulate(
+        "\n"
+        + tabulate(
             total_memory_usage_table,
             tablefmt="outline",
         )
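Note: the only change is prepending "\n" so the table starts on its own line; otherwise the logger prefix (e.g. "INFO:root:") is glued to the first table row and throws off the column alignment. A standalone illustration (the rows and extra headers are made up; the real pass builds memory_usage_table from the planner's placement data):

import logging
from tabulate import tabulate

logging.basicConfig(level=logging.INFO)

# Made-up rows for illustration only.
memory_usage_table = [["DTCM", 1024, 8192], ["SRAM", 512, 65536]]
logging.info(
    "\n"  # push the table below the "INFO:..." prefix
    + tabulate(
        memory_usage_table,
        headers=["Memory Space", "Used (bytes)", "Total (bytes)"],
        tablefmt="outline",
    )
)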

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 161 additions & 1 deletion
@@ -1,7 +1,9 @@
 # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

+import logging
 import math
 import unittest
+from typing import cast

 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -110,7 +112,121 @@ def forward(self, x):


 class TestMemTransform(unittest.TestCase):
-    def test_optimize_cat(self):
+    def _verify_cat_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        dim_offset = 0
+        for arg in cast(list[torch.fx.Node], node.args[0]):
+            arg_spec = arg.meta.get("spec", None)
+            self.assertEqual(arg_spec.mem_id, spec.mem_id)
+            self.assertEqual(
+                arg_spec.mem_offset,
+                spec.mem_offset + dim_offset * inner_dim_elements,
+                f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {dim_offset=} for cat on {dim=}, but output has {spec.mem_offset=}",
+            )
+            dim_offset += arg_spec.shape[dim]
+
+    def _verify_slice_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim + 1 :]) * spec.dtype.itemsize
+        start: int = (
+            cast(int, node.args[2])
+            if (len(node.args) > 2 and node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + start * inner_dim_elements,
+            f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} {start=} for slice on {dim=}, but output has {spec.mem_offset=}",
+        )
+
+    def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
+        spec = node.meta.get("spec", None)
+        self.assertIsNotNone(spec)
+        dim: int = cast(int, node.args[1]) if len(node.args) > 1 else 0
+        outer_size = math.prod(spec.shape[:dim])
+        self.assertEqual(
+            outer_size,
+            1,
+            f"{node=} has wrong outer size: {outer_size=}, expected 1.",
+        )
+        inner_dim_elements = math.prod(spec.shape[dim:]) * spec.dtype.itemsize
+        index: int = (
+            cast(int, node.args[2])
+            if (len(node.args) > 2 and node.args[2] is not None)
+            else 0
+        )
+        arg = cast(torch.fx.Node, node.args[0])
+        arg_spec = arg.meta.get("spec", None)
+        self.assertEqual(arg_spec.mem_id, spec.mem_id)
+        self.assertEqual(
+            spec.mem_offset,
+            arg_spec.mem_offset + index * inner_dim_elements,
+            f"{arg=} for node {node=} has wrong memory offset: {arg_spec.mem_offset=} for select on {dim=} {index=}, "
+            f"but output has {spec.mem_offset=}"
+            f"{spec=} {arg_spec=}",
+        )
+
+    def verify_nop_memory_alloc(self, graph_module):
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._cat_nop.out
+        ):
+            self._verify_cat_nop_memory_alloc(node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._slice_copy_nop.Tensor_out
+        ):
+            self._verify_slice_nop_memory_alloc(node)
+
+        for node in graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten._select_copy_nop.int_out
+        ):
+            self._verify_select_nop_memory_alloc(node)
+
+    def test_optimize_cat_on_placeholders(self):
+        class Cat(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.cat((x, y))
+
+        x = torch.ones(3, 6)
+        y = torch.ones(2, 6)
+        # Optimizing cat ops is only at opt_level 2+, and requires the memory planning
+        # pass to run:
+        graph_module = (
+            compiler.export_to_executorch_gen_etrecord(
+                Cat(), (x, y), opt_level=2, mem_algo=1
+            )
+            .exported_program()
+            .graph_module
+        )
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        graph_module.graph.eliminate_dead_code()
+        # Assert that cat op is optimized away
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        # Assert that cat op is replaced by its nop version post optimization
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_outermost(self):
         class OptimizeCatFeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -135,7 +251,9 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         # Assert that cat op is replaced by its nop version post optimization
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_optimize_cat_non_outermost(self):
         class OptimizeCatFeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -160,7 +278,9 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         # Assert that cat op is replaced by its nop version post optimization
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost(self):
         class OptimizeCatInfeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -184,7 +304,9 @@ def forward(self, x, y):
         # Assert that cat op is not optimized away, since the concat is not
         # along the outermost dim
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

+    def test_no_optimize_cat_non_outermost1(self):
         class OptimizeCatInfeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -209,6 +331,7 @@ def forward(self, x, y):
         # offsets are not multiple of 8 bytes, and the cat is not the output
         # of the graph.
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_slice(self):
         class OptimizeCatSliceFeasible(torch.nn.Module):
@@ -237,6 +360,7 @@ def forward(self, x):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_slice_infeasible(self):
         class OptimizeCatSliceInfeasible(torch.nn.Module):
@@ -262,6 +386,7 @@ def forward(self, x, y):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_slice_Tensor(self):
         class SliceTensor(torch.nn.Module):
@@ -323,6 +448,7 @@ def forward(self, x, y, z):
         self.assertEqual(
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 3
         )
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_select_Tensor(self):
         class SelectTensor(torch.nn.Module):
@@ -387,6 +513,7 @@ def forward(self, x, y, z):
         self.assertEqual(
             count_node(graph_module, torch.ops.aten._select_copy_nop.int_out), 3
         )
+        self.verify_nop_memory_alloc(graph_module)

     # TODO: Test fails due to memory planning
     @unittest.expectedFailure
@@ -416,6 +543,32 @@ def forward(self, x, y):
         graph_module.graph.eliminate_dead_code()
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
+        self.verify_nop_memory_alloc(graph_module)
+
+    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+        class CatWithPadding(torch.nn.Module):
+            def __init__(self, padding_shape):
+                super().__init__()
+                zeros = torch.zeros(padding_shape)
+                self.register_buffer("padding", zeros)
+
+            def forward(self, x, y):
+                x = x.view(3, 5)
+                cat = torch.ops.aten.cat((x, self.padding.clone()))
+                slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0])
+                self.padding.copy_(slice_copy)
+                return cat.view(-1) + y
+
+        x = torch.ones(15)
+        y = torch.ones(1)
+        et_prog_manager = compiler.export_to_executorch_gen_etrecord(
+            CatWithPadding((1, 5)), (x, y), opt_level=3
+        )
+        graph_module = et_prog_manager.exported_program().graph_module
+        logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}")
+        self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_cat_with_view(self):
         class CatViewFeasible(torch.nn.Module):
@@ -442,6 +595,7 @@ def forward(self, x, y):
         # Assert that cat op is optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat_with_repeated_args(self):
         class CatViewInfeasible(torch.nn.Module):
@@ -465,6 +619,7 @@ def forward(self, x):
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat_with_placeholder(self):
         class CatViewInfeasible(torch.nn.Module):
@@ -492,6 +647,7 @@ def forward(self, x, y):
         # Assert that cat op is not optimized away
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_no_optimize_cat(self) -> None:
         class Model(torch.nn.Module):
@@ -522,6 +678,7 @@ def forward(self, x) -> torch.Tensor:
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2
         )
         self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

     def test_optimize_slice_copy(self) -> None:
         class Model(torch.nn.Module):
@@ -553,6 +710,7 @@ def forward(self, x) -> torch.Tensor:
             count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0
         )
         self.assertEqual(count_node(graph_module, memory.view), 2)
+        self.verify_nop_memory_alloc(graph_module)

     def test_cat_then_cat(self) -> None:
         class Model(torch.nn.Module):
@@ -579,6 +737,7 @@ def forward(self, x) -> torch.Tensor:
         graph_module.print_readable()
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2)
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
+        self.verify_nop_memory_alloc(graph_module)

     def test_view_for_unallocated_output(self):
         class Model(torch.nn.Module):
@@ -602,3 +761,4 @@ def forward(self, x, y):
             .graph_module
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
+        self.verify_nop_memory_alloc(graph_module)
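Note: the new _verify_*_nop_memory_alloc helpers all assert the same invariant: once a cat/slice/select has been lowered to its nop variant, each input must already sit at the correct byte offset inside the output's buffer. A standalone sketch of the cat arithmetic, using the (3, 6) and (2, 6) inputs from test_optimize_cat_on_placeholders (the planner-assigned output offset is made up for illustration):

import math
import torch

# cat((x, y)) on dim 0, with x: (3, 6), y: (2, 6), both float32.
shapes = [(3, 6), (2, 6)]
dtype = torch.float32
dim = 0

output_offset = 128  # hypothetical offset chosen by the memory planner
inner_dim_bytes = math.prod(shapes[0][dim + 1:]) * dtype.itemsize  # 6 * 4 = 24

dim_offset = 0
for shape in shapes:
    # Each input must be placed where its rows land inside the output buffer.
    print(f"input {shape} expected at byte offset "
          f"{output_offset + dim_offset * inner_dim_bytes}")
    dim_offset += shape[dim]
# -> x at 128, y at 128 + 3 * 24 = 200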

backends/xnnpack/test/tester/tester.py

Lines changed: 0 additions & 2 deletions
@@ -558,10 +558,8 @@ def export(self, export_stage: Optional[Export] = None):
         )

     def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
-        # TODO(T182187531): Skip dim order for now. Support dim order and its op after alpha release.
         if not to_edge_stage:
             to_edge_stage = ToEdge()
-        to_edge_stage.edge_compile_conf._skip_dim_order = True
         res = self._run_stage(to_edge_stage)
         return res
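Note: with the override removed, the tester stops forcing _skip_dim_order=True, so XNNPACK tests now exercise dim-order ops under the default ToEdge configuration. A test that still needed the old behavior could presumably opt back in by passing an explicit stage (a sketch under that assumption; this diff does not show the ToEdge constructor or Tester API):

import torch
from executorch.exir import EdgeCompileConfig
from executorch.backends.xnnpack.test.tester import Tester, ToEdge

module = torch.nn.Linear(4, 2)
tester = Tester(module, (torch.randn(1, 4),))
# Hypothetical per-test opt-out, mirroring the removed global override.
tester.export().to_edge(ToEdge(EdgeCompileConfig(_skip_dim_order=True)))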

docs/source/index.rst

Lines changed: 3 additions & 3 deletions
@@ -4,9 +4,9 @@ Welcome to the ExecuTorch Documentation
 =======================================

 .. important::
-   v0.4.0 is a beta release of ExecuTorch. As of this release, the API will
-   follow the `API Lifecycle and Deprecation Policy <api-life-cycle.html>`__,
-   and the ``.pte`` binary format will comply with the `Runtime Compatibility
+   v0.4.0 was the beta release of ExecuTorch. Starting from v0.4.0, the API
+   follows the `API Lifecycle and Deprecation Policy <api-life-cycle.html>`__,
+   and the ``.pte`` binary format complies with the `Runtime Compatibility
    Policy
    <https://github.com/pytorch/executorch/tree/main/runtime/COMPATIBILITY.md>`__.
    This helps ensure that application developers can update to the latest

exir/capture/_config.py

Lines changed: 4 additions & 0 deletions
@@ -92,3 +92,7 @@ class ExecutorchBackendConfig:
     # If set to true, all constant tensors will be stored in a separate file,
     # external to the PTE file.
     external_constants: bool = False
+
+    # If set to true, all trainable weights will be stored in a separate file,
+    # external to the PTE file.
+    external_mutable_weights: bool = False
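Note: the new field mirrors external_constants one entry down: opt-in and off by default. A sketch of how a caller might enable it when lowering (assumes the public to_edge/to_executorch flow; whether anything is actually externalized depends on the exported program carrying mutable weights):

import torch
from executorch.exir import ExecutorchBackendConfig, to_edge

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

# Request that trainable weights be serialized outside the .pte file.
edge = to_edge(torch.export.export(TinyModel(), (torch.randn(1, 4),)))
et_program = edge.to_executorch(
    ExecutorchBackendConfig(external_mutable_weights=True)
)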
