From 411fd3f5b58409c43a709ea6df653fcd4b137437 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Tue, 11 Feb 2025 14:03:48 -0800
Subject: [PATCH] Add a pass to convert rank-0 tensor to rank-1 tensor (#8298)

Summary:
Tested with the following code, and no longer see the error/warning complaining about rank-0 tensors:

```
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Model()
model.eval()
example_inputs = (torch.tensor(1.0), torch.tensor(2.0))
exported_program_manager_aten = torch.export.export(model, example_inputs)
exported_program_manager_edge = executorch.exir.to_edge(
    exported_program_manager_aten
).transform([Rank0ToRank1Pass()])
delegated_module = to_backend(
    CoreMLBackend.__name__, exported_program_manager_edge.exported_program(), []
)
```

Differential Revision: D69281867
---
 backends/transforms/rank_0_to_rank_1.py      | 18 +++++++++++
 backends/transforms/targets.bzl              | 27 ++++++++++++++++
 .../transforms/test/test_rank_0_to_rank_1.py | 32 +++++++++++++++++++
 3 files changed, 77 insertions(+)
 create mode 100644 backends/transforms/rank_0_to_rank_1.py
 create mode 100644 backends/transforms/test/test_rank_0_to_rank_1.py

diff --git a/backends/transforms/rank_0_to_rank_1.py b/backends/transforms/rank_0_to_rank_1.py
new file mode 100644
index 00000000000..81159efcc24
--- /dev/null
+++ b/backends/transforms/rank_0_to_rank_1.py
@@ -0,0 +1,18 @@
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class Rank0ToRank1Pass(ExportPass):
+    """
+    Replace rank-0 tensors with rank-1 tensors for all inputs.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.op == "placeholder" and node.meta["val"].shape == ():
+                node.meta["val"] = node.meta["val"].reshape(1, 1)
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl
index 09ef0f59c59..c532798546d 100644
--- a/backends/transforms/targets.bzl
+++ b/backends/transforms/targets.bzl
@@ -187,6 +187,20 @@ def define_common_targets():
         ],
     )
 
+    runtime.python_library(
+        name = "rank_0_to_rank_1",
+        srcs = [
+            "rank_0_to_rank_1.py",
+        ],
+        visibility = [
+            "//executorch/backends/...",
+        ],
+        deps = [
+            "//caffe2:torch",
+            "//executorch/exir:pass_base",
+        ],
+    )
+
     runtime.python_test(
         name = "test_duplicate_dynamic_quant_chain",
         srcs = [
@@ -200,3 +214,16 @@
             "//executorch/exir:lib",
         ],
     )
+
+
+    runtime.python_test(
+        name = "test_rank_0_to_rank_1",
+        srcs = [
+            "test/test_rank_0_to_rank_1.py",
+        ],
+        deps = [
+            "//caffe2:torch",
+            "//executorch/exir:lib",
+            ":rank_0_to_rank_1",
+        ],
+    )
diff --git a/backends/transforms/test/test_rank_0_to_rank_1.py b/backends/transforms/test/test_rank_0_to_rank_1.py
new file mode 100644
index 00000000000..50c6357fb67
--- /dev/null
+++ b/backends/transforms/test/test_rank_0_to_rank_1.py
@@ -0,0 +1,32 @@
+import unittest
+
+import torch
+from executorch.backends.transforms.rank_0_to_rank_1 import Rank0ToRank1Pass
+from executorch.exir import to_edge
+
+
+class TestRank0ToRank1Pass(unittest.TestCase):
+    def test_pass(
+        self,
+    ):
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                return x + y
+
+        model = Model()
+        model.eval()
+
+        example_inputs = (torch.tensor(1.0), torch.tensor(2.0))
+        aten = torch.export.export(model, example_inputs)
+
+        # Check that the input rank is 0
+        for node in aten.graph.nodes:
+            if node.op == "placeholder":
+                self.assertTrue(node.meta["val"].shape == ())
+
+        edge = to_edge(aten).transform([Rank0ToRank1Pass()])
+
+        # Check that the inputs are no longer rank 0 after the pass
+        for node in edge.exported_program().graph.nodes:
+            if node.op == "placeholder":
+                self.assertTrue(node.meta["val"].shape == (1, 1))