Skip to content

Add a pass to convert rank-0 tensor to rank-1 tensor #8298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions backends/transforms/rank_0_to_rank_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
from executorch.exir.pass_base import ExportPass, PassResult


class Rank0ToRank1Pass(ExportPass):
"""
Replace Rank-0 Tensor to Rank-1 Tensor for all the inputs.
"""

def __init__(self) -> None:
super().__init__()

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.op == "placeholder" and node.meta["val"].shape == ():
node.meta["val"] = node.meta["val"].reshape(1, 1)
graph_module.recompile()
return PassResult(graph_module, True)
27 changes: 27 additions & 0 deletions backends/transforms/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,20 @@ def define_common_targets():
],
)

runtime.python_library(
name = "rank_0_to_rank_1",
srcs = [
"rank_0_to_rank_1.py",
],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
],
)

runtime.python_test(
name = "test_duplicate_dynamic_quant_chain",
srcs = [
Expand All @@ -200,3 +214,16 @@ def define_common_targets():
"//executorch/exir:lib",
],
)


runtime.python_test(
name = "test_rank_0_to_rank_1",
srcs = [
"test/test_rank_0_to_rank_1.py",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
":rank_0_to_rank_1",
],
)
32 changes: 32 additions & 0 deletions backends/transforms/test/test_rank_0_to_rank_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest

import torch
from executorch.backends.transforms.rank_0_to_rank_1 import Rank0ToRank1Pass
from executorch.exir import to_edge


class TestRank0ToRank1Pass(unittest.TestCase):
def test_pass(
self,
):
class Model(torch.nn.Module):
def forward(self, x, y):
return x + y

model = Model()
model.eval()

example_inputs = (torch.tensor(1.0), torch.tensor(2.0))
aten = torch.export.export(model, example_inputs)

# Check that the input rank is 0
for node in aten.graph.nodes:
if node.op == "placeholder":
self.assertTrue(node.meta["val"].shape == ())

edge = to_edge(aten).transform([Rank0ToRank1Pass()])

# Check that the input rank is 1
for node in edge.exported_program().graph.nodes:
if node.op == "placeholder":
self.assertTrue(node.meta["val"].shape == (1, 1))
Loading