Skip to content

Commit 7e2b4fa

Browse files
angelayifacebook-github-bot
authored andcommitted
Turn on python dispatcher for EdgeOpArgValidator (#3809)
Summary: Possibly fixes #3659 We need to enable the python dispatcher so that expand_copy and view_copy will go through the correct meta kernels Differential Revision: D58091304
1 parent 13ba3a7 commit 7e2b4fa

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

exir/program/test/test_program.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from executorch.extension.pybindings.portable_lib import (
3333
_load_for_executorch_from_buffer,
3434
)
35-
from torch.export import export, ExportedProgram
35+
from torch.export import Dim, export, ExportedProgram
3636
from torch.export._trace import _export
3737

3838
from torch.library import impl, Library
@@ -272,6 +272,38 @@ def test_edge_manager_transform(self):
272272
original_res, # x * y + x
273273
)
274274

275+
def test_issue_3659(self):
276+
277+
class Mul(torch.nn.Module):
278+
def __init__(self):
279+
super(Mul, self).__init__()
280+
281+
def forward(self, x: torch.Tensor, y: torch.Tensor):
282+
return torch.matmul(x, y)
283+
284+
def get_eager_model(self) -> torch.nn.Module:
285+
return self
286+
287+
def get_example_inputs(self):
288+
return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))
289+
290+
def get_dynamic_shapes(self):
291+
dim1_x = Dim("Dot_dim1_x", min=2, max=100)
292+
dim2_x = Dim("Dot_dim2_x", min=2, max=100)
293+
return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}
294+
295+
model = Mul()
296+
ep = torch.export.export(
297+
model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
298+
)
299+
300+
to_edge(
301+
ep,
302+
compile_config=EdgeCompileConfig(
303+
_check_ir_validity=True,
304+
),
305+
)
306+
275307
def test_transform_dict_api(self):
276308
edge_manager = to_edge(get_exported_programs(), get_config_methods())
277309

exir/verification/verifier.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
EdgeOpArgValidator,
1919
RunHigherOrderOperatorError,
2020
)
21+
from torch._dispatch.python import enable_python_dispatcher
2122

2223
from torch._export.verifier import SpecViolationError, Verifier
2324
from torch._ops import OpOverload
@@ -147,7 +148,8 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
147148
validator = EdgeOpArgValidator(gm)
148149
inputs = _get_inputs(gm)
149150
try:
150-
validator.run(*inputs)
151+
with enable_python_dispatcher():
152+
validator.run(*inputs)
151153
except RunHigherOrderOperatorError:
152154
# NB: ignore higher order operator in the graph.
153155
# If we lower a graph module to delegate and then compose it with some other graph module, retrace it,

0 commit comments

Comments
 (0)