1
1
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
- from typing import Any , Dict , List , Sequence , Tuple
2
+ from collections import defaultdict
3
+ from typing import Any , Dict , Sequence , Tuple
3
4
4
5
import torch
5
6
from executorch .exir .dialects .edge ._ops import EdgeDialectFunctionSchema , EdgeOpOverload
6
7
from executorch .exir .emit ._emitter import _Argument , _Target
7
8
from executorch .exir .error import ExportError , InternalError
9
+ from torch ._ops import HigherOrderOperator
10
+
11
+
12
class RunHigherOrderOperatorError(Exception):
    """
    Raised when we try to run a delegate or other HigherOrderOperator
    inside a graph module, which the validator cannot execute.
    E.g., %executorch_call_delegate : [#users=1] = call_function[
        target=torch.ops.executorch_call_delegate](args = (%lowered_module_0, %arg0_1), kwargs = {})
    """

    def __init__(self, message: str) -> None:
        """
        Args:
            message: Human-readable description of the operator that
                could not be run (stored in ``args`` via ``Exception``).
        """
        # Explicit signature (rather than inheriting Exception's *args)
        # so type checkers enforce a single string message.
        super().__init__(message)
21
+
8
22
9
23
# pyre-ignore[13]: Attribute `node` is never initialized.
10
24
class EdgeOpArgValidator (torch .fx .Interpreter ):
@@ -18,14 +32,16 @@ class EdgeOpArgValidator(torch.fx.Interpreter):
18
32
19
33
def __init__(self, graph_module: torch.fx.GraphModule) -> None:
    """
    Initialize the interpreter over *graph_module* and reset the record
    of dtype-constraint violations.

    Args:
        graph_module: The FX graph module whose nodes will be interpreted
            and validated.
    """
    super().__init__(graph_module)
    # Maps each violating EdgeOpOverload to the {arg/ret name -> dtype}
    # mapping that failed its schema's dtype constraint. defaultdict(dict)
    # allows per-op entries to be filled in incrementally during the run.
    self.violating_ops: Dict[EdgeOpOverload, Dict[str, torch.dtype]] = defaultdict(
        dict
    )
22
38
23
39
def run_node (self , n : torch .fx .Node ) -> None :
24
40
self .node = n
25
41
try :
26
42
ret = super ().run_node (n )
27
43
except Exception as e :
28
- if isinstance (e , (InternalError , ExportError )):
44
+ if isinstance (e , (InternalError , ExportError , RunHigherOrderOperatorError )):
29
45
raise e
30
46
else :
31
47
raise InternalError (str (e )) from e
@@ -40,7 +56,9 @@ def call_function(
40
56
if not isinstance (target , EdgeOpOverload ) or not isinstance (
41
57
target ._schema , EdgeDialectFunctionSchema
42
58
):
43
- return False
59
+ if isinstance (target , HigherOrderOperator ):
60
+ raise RunHigherOrderOperatorError ("Can't run delegate" )
61
+ return super ().call_function (target , args , kwargs )
44
62
tensor_arg_types : Dict [str , torch .dtype ] = {}
45
63
for i , schema_arg in enumerate (target ._schema .arguments ):
46
64
if not isinstance (schema_arg .type , torch .TensorType ):
@@ -60,15 +78,15 @@ def call_function(
60
78
kernel_rets if isinstance (kernel_rets , Sequence ) else [kernel_rets ]
61
79
)
62
80
for schema_ret in target ._schema .returns :
63
- if isinstance ( schema_ret .type , torch . TensorType ):
64
- name = schema_ret . name if schema_ret . name else f"__ret_ { ret_index } "
65
- kernel_ret = next ( ret_iter )
66
- if not isinstance ( kernel_ret , torch .Tensor ):
67
- continue
81
+ name = schema_ret .name if schema_ret . name else f"__ret_ { ret_index } "
82
+ kernel_ret = next ( ret_iter )
83
+ if isinstance ( schema_ret . type , torch . TensorType ) and isinstance (
84
+ kernel_ret , torch .Tensor
85
+ ):
68
86
tensor_arg_types [name ] = kernel_ret .dtype
69
87
ret_index += 1
70
88
71
89
valid = target ._schema .dtype_constraint .validate (tensor_arg_types )
72
90
if not valid :
73
- self .violating_ops . append (( target , tensor_arg_types ))
74
- return valid
91
+ self .violating_ops [ target ] = tensor_arg_types
92
+ return super (). call_function ( target , args , kwargs )
0 commit comments