Commit fd41791
Turn on python dispatcher for EdgeOpArgValidator (#3809)
Summary:
Pull Request resolved: #3809

Possibly fixes #3659

We need to enable the python dispatcher so that expand_copy and view_copy will go through the correct meta kernels.

Reviewed By: larryliu0820

Differential Revision: D58091304

fbshipit-source-id: f8907ee130720b01c629d55f222eb5a7e63a34bd
(cherry picked from commit ab6f177)
1 parent 50d1da2 commit fd41791
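For context on why this matters: Python-registered meta kernels are only visible to dispatch when the python dispatcher is active, so shape propagation for ops such as aten.view_copy and aten.expand_copy on meta/fake tensors can resolve to the wrong kernel (or fail) without it. A minimal illustrative sketch, not part of the commit (the op and shapes here are arbitrary examples):

import torch
from torch._dispatch.python import enable_python_dispatcher

# Meta tensors carry only shape/dtype metadata, so calling an op on them
# exercises its meta kernel rather than a real compute kernel.
x = torch.empty(1, 3, 10, device="meta")

with enable_python_dispatcher():
    # With the python dispatcher on, Python-level meta kernels/decompositions
    # (e.g. for the *_copy ops) participate in dispatch.
    out = torch.ops.aten.view_copy.default(x, [1, 30])

print(out.shape)  # torch.Size([1, 30])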

2 files changed: +37, -3 lines
exir/program/test/test_program.py (+34, -2)

@@ -10,7 +10,7 @@
 from typing import Any, Dict
 
 import torch
-from executorch.exir import ExecutorchBackendConfig
+from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
 from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.error import ExportError
@@ -26,7 +26,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
-from torch.export import export, ExportedProgram
+from torch.export import Dim, export, ExportedProgram
 
 from torch.library import impl, Library
 
@@ -225,6 +225,38 @@ def test_edge_manager_transform(self):
             original_res,  # x * y + x
         )
 
+    def test_issue_3659(self):
+
+        class Mul(torch.nn.Module):
+            def __init__(self):
+                super(Mul, self).__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return torch.matmul(x, y)
+
+            def get_eager_model(self) -> torch.nn.Module:
+                return self
+
+            def get_example_inputs(self):
+                return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))
+
+            def get_dynamic_shapes(self):
+                dim1_x = Dim("Dot_dim1_x", min=2, max=100)
+                dim2_x = Dim("Dot_dim2_x", min=2, max=100)
+                return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}
+
+        model = Mul()
+        ep = torch.export.export(
+            model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
+        )
+
+        to_edge(
+            ep,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=True,
+            ),
+        )
+
     def test_transform_dict_api(self):
         edge_manager = to_edge(get_exported_programs(), get_config_methods())
exir/verification/verifier.py (+3, -1)

@@ -17,6 +17,7 @@
     EdgeOpArgValidator,
     RunHigherOrderOperatorError,
 )
+from torch._dispatch.python import enable_python_dispatcher
 
 from torch._export.verifier import SpecViolationError, Verifier
 from torch._ops import OpOverload
@@ -119,7 +120,8 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
     validator = EdgeOpArgValidator(gm)
     inputs = _get_inputs(gm)
     try:
-        validator.run(*inputs)
+        with enable_python_dispatcher():
+            validator.run(*inputs)
     except RunHigherOrderOperatorError:
         # NB: ignore higher order operator in the graph.
         # If we lower a graph module to delegate and then compose it with some other graph module, retrace it,
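The change applies a common pattern for re-running an exported/FX graph on fake inputs during validation: wrap the interpreter run in enable_python_dispatcher so symbolic-shape meta kernels are used. A rough sketch of that pattern with a plain fx.Interpreter (an assumption-laden illustration, not EdgeOpArgValidator's actual implementation):

import torch
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import FakeTensorMode

def run_on_fake_inputs(gm: fx.GraphModule, example_inputs):
    # Turn the example inputs into fake tensors, then interpret the graph with
    # the python dispatcher enabled so Python-registered meta kernels (e.g. for
    # expand_copy / view_copy with symbolic shapes) are selected.
    mode = FakeTensorMode(allow_non_fake_inputs=True)
    fake_inputs = [mode.from_tensor(t) for t in example_inputs]
    with mode, enable_python_dispatcher():
        return fx.Interpreter(gm).run(*fake_inputs)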
