Skip to content

[Not for Commit] Example use of new dim_order api #8289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions exir/tests/test_memory_format_ops_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test_op_dim_order_update(self) -> None:
)

def test_op_dim_order_propagation(self) -> None:
print("test_op_dim_order_propagation: unambiguous path")
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
Expand All @@ -126,6 +127,24 @@ def test_op_dim_order_propagation(self) -> None:
),
)

print("test_op_dim_order_propagation: ambiguous path")
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
self,
MemoryFormatTestSet(
module=PropagateToCopyChannalsLastModule().eval(),
op=torch.ops.aten._to_copy.default,
sample_input=(
torch.rand_like(
torch.zeros([2, 1, 2, 2]),
dtype=torch.float32,
memory_format=torch.contiguous_format,
),
),
target_memory_format=torch.channels_last,
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
),
)

# Only test dim order replacement result in lean mode test.
# This test is irrelevant with operator mode.
def test_dim_order_replacement(self) -> None:
Expand Down
52 changes: 52 additions & 0 deletions exir/tests/test_memory_format_ops_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@
is_channel_last_dim_order,
is_contiguous_dim_order,
)
from executorch.exir.pass_base import ExportPass

from exir.passes.memory_format_ops_pass import MemoryFormatOpsPass

from torch.export import export

from torch.fx.passes.infra.pass_manager import PassManager
from torch.testing import FileCheck
from torch.utils._pytree import tree_flatten

Expand Down Expand Up @@ -99,6 +104,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return t1 * t2


def assert_unambiguous_dim_order(gm):
    """
    Run an example pass pipeline over ``gm`` and raise ``RuntimeError`` if any
    call_function node produces a tensor whose dim order is ambiguous with
    respect to contiguous vs. channels_last memory formats.

    Args:
        gm: a ``torch.fx.GraphModule`` whose nodes carry a ``meta["val"]``
            FakeTensor entry (as produced by export).

    Raises:
        RuntimeError: if ``Tensor.dim_order(ambiguity_check=...)`` reports an
            ambiguous dim order for any node; the original exception is
            chained as the cause.
    """

    # This is just an example, you can add your own pass or passes.
    class ExampleNOPPass(ExportPass):
        """
        Does nothing!
        """

        def call_operator(self, op, args, kwargs, meta):
            return super().call_operator(
                op,
                args,
                kwargs,
                meta,
            )

    # This is an example of how one can detect ambiguous dim_order anywhere in the graph.
    # You can be surgical and only detect it in the nodes you are interested in or something else.
    def detect_ambiguity(gm):
        """
        Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats.
        """
        for node in gm.graph.nodes:
            if node.op != "call_function":
                continue
            tensor = node.meta.get("val")
            # Multi-output ops store a list/tuple in meta["val"]; dim_order is
            # only defined on a single tensor, so skip anything else instead of
            # crashing with an unrelated AttributeError.
            if not isinstance(tensor, torch.Tensor):
                continue
            # Let's make sure dim_order is not ambiguous, raise otherwise.
            # This is raising because we can't do anything about it.
            # The right course of follow up action is to ask user to try with a different example input.
            print(f"node: {node}, shape: {tensor.shape}, ", end="")

            try:
                dim_order = tensor.dim_order(
                    ambiguity_check=[torch.contiguous_format, torch.channels_last]
                )
                print(f"dim_order: {dim_order}")
            except Exception as e:
                print("")
                # Chain the original exception (`from e`) so the root cause and
                # the offending node both show up in the traceback.
                raise RuntimeError(
                    f"ambiguous dim order for node {node}: {e}"
                ) from e

    # Any pass or passes work here; ExampleNOPPass is just a placeholder.
    dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()])
    dim_order_pass_manager.add_checks(detect_ambiguity)
    dim_order_pass_manager(gm)


class MemoryFormatOpsPassTestUtils:
@staticmethod
def memory_format_test_runner(
Expand All @@ -121,6 +170,9 @@ def memory_format_test_runner(
before, compile_config=EdgeCompileConfig(_skip_dim_order=False)
)

# Just as an example
assert_unambiguous_dim_order(epm.exported_program().graph_module)

# check memory format ops, if needed
if test_set.op_level_check:
aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]
Expand Down
Loading