Skip to content

Commit 9f58102

Browse files
larryliu0820 authored and facebook-github-bot committed
Fix edge dialect verifier bugs
Summary: ## Problem If we lower a graph module to delegate and then compose it with some other graph module, retrace it, if we also turn on edge ops and validator (`_use_edge_ops=True`, `_check_ir_validity=False`), validator will error out because it doesn't know how to handle lowered module. ## Solution Fix this by ignoring all HigherOrderOps: catch the exception and pass. Reviewed By: JacobSzwejbka Differential Revision: D46608636 fbshipit-source-id: 91fc7a887702f40c25a4595717210e8e32c3700a
1 parent d368d66 commit 9f58102

File tree

3 files changed

+48
-16
lines changed

3 files changed

+48
-16
lines changed

exir/tests/test_arg_validator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from executorch import exir
88
from executorch.exir import EdgeCompileConfig
99
from executorch.exir.dialects._ops import ops
10-
from executorch.exir.graph_module import get_exir_meta
10+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1111
from executorch.exir.verification.arg_validator import EdgeOpArgValidator
1212

1313

@@ -55,12 +55,13 @@ def forward(self, x):
5555
validator = EdgeOpArgValidator(egm)
5656
validator.run(*inputs)
5757
self.assertEqual(len(validator.violating_ops), 1)
58+
key: EdgeOpOverload = next(iter(validator.violating_ops))
5859
self.assertEqual(
59-
validator.violating_ops[0][0].name(),
60+
key.name(),
6061
ops.edge.aten._log_softmax.default.name(),
6162
)
6263
self.assertDictEqual(
63-
validator.violating_ops[0][1],
64+
validator.violating_ops[key],
6465
{
6566
"self": torch.bfloat16,
6667
"__ret_0": torch.bfloat16,

exir/verification/arg_validator.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2-
from typing import Any, Dict, List, Sequence, Tuple
2+
from collections import defaultdict
3+
from typing import Any, Dict, Sequence, Tuple
34

45
import torch
56
from executorch.exir.dialects.edge._ops import EdgeDialectFunctionSchema, EdgeOpOverload
67
from executorch.exir.emit._emitter import _Argument, _Target
78
from executorch.exir.error import ExportError, InternalError
9+
from torch._ops import HigherOrderOperator
10+
11+
12+
class RunHigherOrderOperatorError(Exception):
13+
"""
14+
Raised when we try to run a delegate or other HigherOrderOperator in a graph module.
15+
E.g., %executorch_call_delegate : [#users=1] = call_function[
16+
target=torch.ops.executorch_call_delegate](args = (%lowered_module_0, %arg0_1), kwargs = {})
17+
"""
18+
19+
def __init__(self, message: str) -> None:
20+
super().__init__(message)
21+
822

923
# pyre-ignore[13]: Attribute `node` is never initialized.
1024
class EdgeOpArgValidator(torch.fx.Interpreter):
@@ -18,14 +32,16 @@ class EdgeOpArgValidator(torch.fx.Interpreter):
1832

1933
def __init__(self, graph_module: torch.fx.GraphModule) -> None:
2034
super().__init__(graph_module)
21-
self.violating_ops: List[Tuple[EdgeOpOverload, Dict[str, torch.dtype]]] = []
35+
self.violating_ops: Dict[EdgeOpOverload, Dict[str, torch.dtype]] = defaultdict(
36+
dict
37+
)
2238

2339
def run_node(self, n: torch.fx.Node) -> None:
2440
self.node = n
2541
try:
2642
ret = super().run_node(n)
2743
except Exception as e:
28-
if isinstance(e, (InternalError, ExportError)):
44+
if isinstance(e, (InternalError, ExportError, RunHigherOrderOperatorError)):
2945
raise e
3046
else:
3147
raise InternalError(str(e)) from e
@@ -40,7 +56,9 @@ def call_function(
4056
if not isinstance(target, EdgeOpOverload) or not isinstance(
4157
target._schema, EdgeDialectFunctionSchema
4258
):
43-
return False
59+
if isinstance(target, HigherOrderOperator):
60+
raise RunHigherOrderOperatorError("Can't run delegate")
61+
return super().call_function(target, args, kwargs)
4462
tensor_arg_types: Dict[str, torch.dtype] = {}
4563
for i, schema_arg in enumerate(target._schema.arguments):
4664
if not isinstance(schema_arg.type, torch.TensorType):
@@ -60,15 +78,15 @@ def call_function(
6078
kernel_rets if isinstance(kernel_rets, Sequence) else [kernel_rets]
6179
)
6280
for schema_ret in target._schema.returns:
63-
if isinstance(schema_ret.type, torch.TensorType):
64-
name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
65-
kernel_ret = next(ret_iter)
66-
if not isinstance(kernel_ret, torch.Tensor):
67-
continue
81+
name = schema_ret.name if schema_ret.name else f"__ret_{ret_index}"
82+
kernel_ret = next(ret_iter)
83+
if isinstance(schema_ret.type, torch.TensorType) and isinstance(
84+
kernel_ret, torch.Tensor
85+
):
6886
tensor_arg_types[name] = kernel_ret.dtype
6987
ret_index += 1
7088

7189
valid = target._schema.dtype_constraint.validate(tensor_arg_types)
7290
if not valid:
73-
self.violating_ops.append((target, tensor_arg_types))
74-
return valid
91+
self.violating_ops[target] = tensor_arg_types
92+
return super().call_function(target, args, kwargs)

exir/verification/verifier.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from executorch.exir.delegate import executorch_call_delegate
55
from executorch.exir.dialects.edge._ops import EdgeOpOverload
66
from executorch.exir.error import ExportError, ExportErrorType
7-
from executorch.exir.verification.arg_validator import EdgeOpArgValidator
7+
from executorch.exir.verification.arg_validator import (
8+
EdgeOpArgValidator,
9+
RunHigherOrderOperatorError,
10+
)
811

912
from torch._export.verifier import (
1013
_check_has_fake_tensor,
@@ -70,7 +73,17 @@ def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
7073
def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
7174
validator = EdgeOpArgValidator(gm)
7275
inputs = _get_inputs(gm)
73-
validator.run(*inputs)
76+
try:
77+
validator.run(*inputs)
78+
except RunHigherOrderOperatorError:
79+
# NB: ignore higher order operator in the graph.
80+
# If we lower a graph module to delegate and then compose it with some other graph module, retrace it,
81+
# if we also turn on edge ops and validator (_use_edge_ops=True, _check_ir_validity=True), we will run
82+
# into RunHigherOrderOperatorError. The only thing we can do right now is to ignore this error, since
83+
# by definition it's still a valid Edge dialect. This is not ideal because it ignores possible invalidity
84+
# later in the graph.
85+
return
86+
7487
if validator.violating_ops:
7588
raise SpecViolationError(
7689
f"These operators are taking Tensor inputs with mismatched dtypes: {validator.violating_ops}"

0 commit comments

Comments
 (0)