From 7e2b4fa12efc9cb4cbbd315df431b7fae29cc589 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Mon, 3 Jun 2024 09:54:55 -0700 Subject: [PATCH] Turn on python dispatcher for EdgeOpArgValidator (#3809) Summary: Possibly fixes https://github.com/pytorch/executorch/issues/3659 We need to enable the python dispatcher so that expand_copy and view_copy will go through the correct meta kernels Differential Revision: D58091304 --- exir/program/test/test_program.py | 34 ++++++++++++++++++++++++++++++- exir/verification/verifier.py | 4 +++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index b189e6b0f06..73057feec3d 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -32,7 +32,7 @@ from executorch.extension.pybindings.portable_lib import ( _load_for_executorch_from_buffer, ) -from torch.export import export, ExportedProgram +from torch.export import Dim, export, ExportedProgram from torch.export._trace import _export from torch.library import impl, Library @@ -272,6 +272,38 @@ def test_edge_manager_transform(self): original_res, # x * y + x ) + def test_issue_3659(self): + + class Mul(torch.nn.Module): + def __init__(self): + super(Mul, self).__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.matmul(x, y) + + def get_eager_model(self) -> torch.nn.Module: + return self + + def get_example_inputs(self): + return (torch.randn(1, 3, 10), torch.randn(1, 10, 3)) + + def get_dynamic_shapes(self): + dim1_x = Dim("Dot_dim1_x", min=2, max=100) + dim2_x = Dim("Dot_dim2_x", min=2, max=100) + return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}} + + model = Mul() + ep = torch.export.export( + model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes() + ) + + to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=True, + ), + ) + def test_transform_dict_api(self): edge_manager = to_edge(get_exported_programs(), get_config_methods()) diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index 43003c9253a..5e4c24b1f6a 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -18,6 +18,7 @@ EdgeOpArgValidator, RunHigherOrderOperatorError, ) +from torch._dispatch.python import enable_python_dispatcher from torch._export.verifier import SpecViolationError, Verifier from torch._ops import OpOverload @@ -147,7 +148,8 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None: validator = EdgeOpArgValidator(gm) inputs = _get_inputs(gm) try: - validator.run(*inputs) + with enable_python_dispatcher(): + validator.run(*inputs) except RunHigherOrderOperatorError: # NB: ignore higher order operator in the graph. # If we lower a graph module to delegate and then compose it with some other graph module, retrace it,