
Commit ff9cb64

cccclai authored and facebook-github-bot committed
Add a pass to convert rank-0 tensor to rank-1 tensor (#8298)
Summary: Tested with the following code; the error/warning complaining about rank-0 tensor inputs no longer appears.

```
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
model.eval()
example_inputs = (torch.tensor(1.0), torch.tensor(2.0))
exported_program_manager_aten = torch.export.export(model, example_inputs)
exported_program_manager_edge = executorch.exir.to_edge(
    exported_program_manager_aten
).transform([Rank0ToRank1Pass()])
delegated_module = to_backend(
    CoreMLBackend.__name__,
    exported_program_manager_edge.exported_program(),
    [],
)
```

Differential Revision: D69281867
1 parent 70143a2 commit ff9cb64

File tree (3 files changed: +77 -0 lines)

- backends/transforms/rank_0_to_rank_1.py (new)
- backends/transforms/targets.bzl
- backends/transforms/test/test_rank_0_to_rank_1.py (new)

backends/transforms/rank_0_to_rank_1.py (+18)

```
import torch
from executorch.exir.pass_base import ExportPass, PassResult


class Rank0ToRank1Pass(ExportPass):
    """
    Replace rank-0 tensors with rank-1 tensors for all inputs.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.op == "placeholder" and node.meta["val"].shape == ():
                node.meta["val"] = node.meta["val"].reshape(1, 1)
        graph_module.recompile()
        return PassResult(graph_module, True)
```
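For illustration only (not part of the commit), a tiny snippet showing the shape change the pass records on each rank-0 placeholder's `meta["val"]`:

```
import torch

t = torch.tensor(1.0)             # a rank-0 (scalar) tensor
print(t.shape)                    # torch.Size([])
print(t.reshape(1, 1).shape)      # torch.Size([1, 1]), the shape written back to node.meta["val"]
```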

backends/transforms/targets.bzl (+27)

```
@@ -187,6 +187,20 @@ def define_common_targets():
         ],
     )
 
+    runtime.python_library(
+        name = "rank_0_to_rank_1",
+        srcs = [
+            "rank_0_to_rank_1.py",
+        ],
+        visibility = [
+            "//executorch/backends/...",
+        ],
+        deps = [
+            "//caffe2:torch",
+            "//executorch/exir:pass_base",
+        ],
+    )
+
     runtime.python_test(
         name = "test_duplicate_dynamic_quant_chain",
         srcs = [
@@ -200,3 +214,16 @@ def define_common_targets():
             "//executorch/exir:lib",
         ],
     )
+
+
+    runtime.python_test(
+        name = "test_rank_0_to_rank_1",
+        srcs = [
+            "test/test_rank_0_to_rank_1.py",
+        ],
+        deps = [
+            "//caffe2:torch",
+            "//executorch/exir:lib",
+            ":rank_0_to_rank_1",
+        ],
+    )
```
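The new `runtime.python_test` target runs the test under Buck. A minimal sketch (not part of this commit) for running the same test with plain `unittest`, assuming the executorch Python package is importable and the working directory is the repository root:

```
import unittest

# Discover the test added by this commit via its repository-relative path
# (taken from the srcs entry above).
suite = unittest.defaultTestLoader.discover(
    "backends/transforms/test", pattern="test_rank_0_to_rank_1.py"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```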
backends/transforms/test/test_rank_0_to_rank_1.py (+32)

```
import unittest

import torch
from executorch.backends.transforms.rank_0_to_rank_1 import Rank0ToRank1Pass
from executorch.exir import to_edge
class TestRank0ToRank1Pass(unittest.TestCase):
    def test_pass(
        self,
    ):
        class Model(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        model = Model()
        model.eval()

        example_inputs = (torch.tensor(1.0), torch.tensor(2.0))
        aten = torch.export.export(model, example_inputs)

        # Check that the input rank is 0
        for node in aten.graph.nodes:
            if node.op == "placeholder":
                self.assertTrue(node.meta["val"].shape == ())

        edge = to_edge(
            aten
        ).transform([Rank0ToRank1Pass()])

        # Check that the input rank is 1
        for node in edge.exported_program().graph.nodes:
            if node.op == "placeholder":
                self.assertTrue(node.meta["val"].shape == (1, 1))
```
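For reference, a sketch of the end-to-end flow from the commit summary with the imports it would need; the `to_backend` and `CoreMLBackend` import paths are assumptions, not established by this commit, and the Core ML backend dependencies must be installed:

```
import torch
from executorch.backends.transforms.rank_0_to_rank_1 import Rank0ToRank1Pass
from executorch.exir import to_edge

# Assumed import paths for the delegation step shown in the commit summary.
from executorch.exir.backend.backend_api import to_backend
from executorch.backends.apple.coreml.compiler import CoreMLBackend


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model().eval()
example_inputs = (torch.tensor(1.0), torch.tensor(2.0))

aten = torch.export.export(model, example_inputs)
# Promote rank-0 placeholder metadata before delegating.
edge = to_edge(aten).transform([Rank0ToRank1Pass()])
delegated = to_backend(CoreMLBackend.__name__, edge.exported_program(), [])
```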

0 commit comments
