Skip to content

Add a pass to replace nodes with empty tensors with full. #8130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions backends/cadence/aot/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
from typing import Optional, Sequence, Union

import torch
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.pass_base import (
Argument,
ExportPass,
NodeMetadata,
PassResult,
ProxyValue,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.node import Argument, Target
from torch.fx.node import Target
from torch.utils import _pytree as pytree


Expand Down
22 changes: 22 additions & 0 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,10 +2071,32 @@ def call_operator(
)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceEmptyTensorsWithFullPass(ExportPass):
"""Replaces nodes that produce empty tensors with full nodes."""

def call_operator(self, op, args, kwargs, meta):
val = meta.data.get("val", None)
if isinstance(val, torch.Tensor) and val.numel() == 0:
return super().call_operator(
exir_ops.edge.aten.full.default,
args=(val.shape, 0),
kwargs={"dtype": val.dtype},
meta=meta,
)
return super().call_operator(op, args, kwargs, meta)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
ret = super().call(graph_module)
modified = ret.graph_module.graph.eliminate_dead_code() or ret.modified
return PassResult(ret.graph_module, modified)


# This class encapsulates all the functions that replace/switch one op in the
# graph with another.
class CadenceReplaceOpsInGraph:
passes = [
ReplaceEmptyTensorsWithFullPass,
ReplaceFunctionallyEquivalentOpTargets,
ReplaceTCopyWithTransposePass,
ReplacePermuteWithTransposePass,
Expand Down
3 changes: 1 addition & 2 deletions backends/cadence/aot/tests/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def test_graph_with_single_im2row(self) -> None:
channels_last = False
im2row = builder.call_operator(
exir_ops.edge.cadence.im2row.default,
# pyre-ignore
(
x,
(2, 2),
Expand Down Expand Up @@ -80,7 +79,7 @@ def _get_inner_graph(self, x_shape: Sequence[int]) -> torch.fx.GraphModule:
x = builder.placeholder("x", torch.randn(*x_shape))
add = builder.call_operator(
exir_ops.edge.aten.add.Tensor,
(x, x), # pyre-ignore
(x, x),
)
builder.output([x, add])
gm = builder.get_graph_module()
Expand Down
60 changes: 59 additions & 1 deletion backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.graph_builder import (
GraphBuilder,
single_op_builder,
)
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.replace_ops import (
ForceChannelLastForConvPass,
Expand All @@ -18,6 +21,7 @@
ReplaceConstantPadNdWithSlicePass,
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
ReplaceConvWithIm2RowAndLinear,
ReplaceEmptyTensorsWithFullPass,
ReplaceFunctionallyEquivalentOpTargets,
ReplaceIm2RowWithViewPass,
ReplaceLinearWithFullyConnectedOpPass,
Expand Down Expand Up @@ -1681,3 +1685,57 @@ def test_cat_insert_transpose(self):
count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
3,
)


class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase):
def _get_slice_empty_gm(self) -> torch.fx.GraphModule:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4))
# This is empty (numel == 0).
slice0 = builder.call_operator(
exir_ops.edge.aten.slice_copy.Tensor, (x, 0, 0, 0)
)
# Copy of x.
slice1 = builder.call_operator(exir_ops.edge.aten.slice_copy.Tensor, (x,))
cat = builder.call_operator(
exir_ops.edge.aten.cat.default,
((slice0, slice1),),
)
builder.output([cat])
return builder.get_graph_module()

def test_empty_slice(self):
gm = self._get_slice_empty_gm()
self.assertEqual(
len(
gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
),
2,
)
self.assertEqual(
len(
gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.full.default
)
),
0,
)
updated_gm = ReplaceEmptyTensorsWithFullPass()(gm).graph_module
self.assertEqual(
len(
updated_gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
)
),
1,
)
self.assertEqual(
len(
updated_gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.full.default
)
),
1,
)
Loading