From 3223f3699cf4134a65c96f97674b5bdb58efc28a Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Tue, 4 Feb 2025 10:40:07 -0800 Subject: [PATCH] Add a pass to replace nodes with empty tensors with full. (#8130) Summary: Remove subgraphs of ops that produce empty tensors at the end. `ReplaceEmptyTensorsWithFullPass` both does the replacement and dead code elimination. Reviewed By: zonglinpeng Differential Revision: D68907459 --- backends/cadence/aot/graph_builder.py | 10 +++- backends/cadence/aot/replace_ops.py | 22 +++++++ .../cadence/aot/tests/test_graph_builder.py | 3 +- .../aot/tests/test_replace_ops_passes.py | 60 ++++++++++++++++++- 4 files changed, 90 insertions(+), 5 deletions(-) diff --git a/backends/cadence/aot/graph_builder.py b/backends/cadence/aot/graph_builder.py index fc9441891a3..9eea77d6aa6 100644 --- a/backends/cadence/aot/graph_builder.py +++ b/backends/cadence/aot/graph_builder.py @@ -6,10 +6,16 @@ from typing import Optional, Sequence, Union import torch -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue +from executorch.exir.pass_base import ( + Argument, + ExportPass, + NodeMetadata, + PassResult, + ProxyValue, +) from torch._dispatch.python import enable_python_dispatcher from torch._subclasses import FakeTensor, FakeTensorMode -from torch.fx.node import Argument, Target +from torch.fx.node import Target from torch.utils import _pytree as pytree diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 89ef821c569..cccc56effa9 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2071,10 +2071,32 @@ def call_operator( ) +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceEmptyTensorsWithFullPass(ExportPass): + """Replaces nodes that produce empty tensors with full nodes.""" + + def call_operator(self, op, args, kwargs, meta): + val = meta.data.get("val", None) + if isinstance(val, torch.Tensor) and val.numel() == 0: + return super().call_operator( + exir_ops.edge.aten.full.default, + args=(val.shape, 0), + kwargs={"dtype": val.dtype}, + meta=meta, + ) + return super().call_operator(op, args, kwargs, meta) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + ret = super().call(graph_module) + modified = ret.graph_module.graph.eliminate_dead_code() or ret.modified + return PassResult(ret.graph_module, modified) + + # This class encapsulates all the functions that replace/switch one op in the # graph with another. class CadenceReplaceOpsInGraph: passes = [ + ReplaceEmptyTensorsWithFullPass, ReplaceFunctionallyEquivalentOpTargets, ReplaceTCopyWithTransposePass, ReplacePermuteWithTransposePass, diff --git a/backends/cadence/aot/tests/test_graph_builder.py b/backends/cadence/aot/tests/test_graph_builder.py index ebef97be52a..750c02fef9f 100644 --- a/backends/cadence/aot/tests/test_graph_builder.py +++ b/backends/cadence/aot/tests/test_graph_builder.py @@ -26,7 +26,6 @@ def test_graph_with_single_im2row(self) -> None: channels_last = False im2row = builder.call_operator( exir_ops.edge.cadence.im2row.default, - # pyre-ignore ( x, (2, 2), @@ -80,7 +79,7 @@ def _get_inner_graph(self, x_shape: Sequence[int]) -> torch.fx.GraphModule: x = builder.placeholder("x", torch.randn(*x_shape)) add = builder.call_operator( exir_ops.edge.aten.add.Tensor, - (x, x), # pyre-ignore + (x, x), ) builder.output([x, add]) gm = builder.get_graph_module() diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index fb6f134fd95..1282e4e9b25 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -7,7 +7,10 @@ import torch.nn.functional as F from executorch.backends.cadence.aot import compiler from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2 -from executorch.backends.cadence.aot.graph_builder import single_op_builder +from executorch.backends.cadence.aot.graph_builder import ( + GraphBuilder, + single_op_builder, +) from executorch.backends.cadence.aot.pass_utils import count_node from executorch.backends.cadence.aot.replace_ops import ( ForceChannelLastForConvPass, @@ -18,6 +21,7 @@ ReplaceConstantPadNdWithSlicePass, ReplaceConvolutionOptionalArgsWithConcreteArgsPass, ReplaceConvWithIm2RowAndLinear, + ReplaceEmptyTensorsWithFullPass, ReplaceFunctionallyEquivalentOpTargets, ReplaceIm2RowWithViewPass, ReplaceLinearWithFullyConnectedOpPass, @@ -1681,3 +1685,57 @@ def test_cat_insert_transpose(self): count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int), 3, ) + + +class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase): + def _get_slice_empty_gm(self) -> torch.fx.GraphModule: + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(4)) + # This is empty (numel == 0). + slice0 = builder.call_operator( + exir_ops.edge.aten.slice_copy.Tensor, (x, 0, 0, 0) + ) + # Copy of x. + slice1 = builder.call_operator(exir_ops.edge.aten.slice_copy.Tensor, (x,)) + cat = builder.call_operator( + exir_ops.edge.aten.cat.default, + ((slice0, slice1),), + ) + builder.output([cat]) + return builder.get_graph_module() + + def test_empty_slice(self): + gm = self._get_slice_empty_gm() + self.assertEqual( + len( + gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + ), + 2, + ) + self.assertEqual( + len( + gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.full.default + ) + ), + 0, + ) + updated_gm = ReplaceEmptyTensorsWithFullPass()(gm).graph_module + self.assertEqual( + len( + updated_gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor + ) + ), + 1, + ) + self.assertEqual( + len( + updated_gm.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.full.default + ) + ), + 1, + )