Skip to content

Commit ec7367c

Browse files
committed
Update base for Update on "Bump ExecuTorch's PyTorch nightly pin to dev20241121"
Require at least 11/18 to unblock #7040 . Differential Revision: [D66398425](https://our.internmc.facebook.com/intern/diff/D66398425/) [ghstack-poisoned]
2 parents 71801b1 + ddec0c7 commit ec7367c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+1798
-479
lines changed

backends/arm/TARGETS

+11
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,14 @@ python_library(
110110
"//executorch/backends/arm/operators:node_visitor",
111111
],
112112
)
113+
114+
python_library(
115+
name = "arm_model_evaluator",
116+
src = [
117+
"util/arm_model_evaluator.py",
118+
],
119+
typing = True,
120+
deps = [
121+
"//caffe2:torch",
122+
]
123+
)

backends/arm/_passes/arm_pass_manager.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
DecomposeSoftmaxesPass,
3030
)
3131
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
32-
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
33-
InsertSqueezeAfterSumPass,
32+
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
33+
KeepDimsFalseToSqueezePass,
3434
)
3535
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
3636
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
@@ -71,7 +71,7 @@ def transform_to_backend_pipeline(
7171
self.add_pass(DecomposeMeanDimPass())
7272
self.add_pass(MatchArgRanksPass(exported_program))
7373
self.add_pass(DecomposeDivPass())
74-
self.add_pass(InsertSqueezeAfterSumPass())
74+
self.add_pass(KeepDimsFalseToSqueezePass())
7575
self.add_pass(ConvertSplitToSlicePass())
7676
self.add_pass(Conv1dUnsqueezePass(exported_program))
7777
self.add_pass(DecomposeSoftmaxesPass())

backends/arm/_passes/arm_pass_utils.py

+58
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# pyre-unsafe
99

10+
from inspect import isclass
1011
from typing import Optional
1112

1213
import torch
@@ -133,3 +134,60 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
133134
fake_tensor, FakeTensor
134135
), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
135136
return fake_tensor
137+
138+
139+
def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
140+
"""
141+
Help-function for getting a value from node.args/ kwargs, three cases:
142+
1. By position in node.args - Returns arg at given position or default_value if index is one out of bounds
143+
2. By key in node.kwargs - Returns kwarg with given key or default_value if it deos not exist
144+
3. By type in node.args - Returns first arg of args of given type. Useful for cases where arg postions may differ but types are unique.
145+
"""
146+
if isinstance(key, int):
147+
if 0 <= key < len(args):
148+
return args[key]
149+
elif key == len(args):
150+
if default_value is not None:
151+
return default_value
152+
else:
153+
raise RuntimeError(f"No defult value given for index {key}")
154+
else:
155+
raise RuntimeError(
156+
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
157+
)
158+
elif isinstance(key, str):
159+
return args.get(key, default_value)
160+
elif isclass(key):
161+
for arg in args:
162+
if isinstance(arg, key):
163+
return arg
164+
if default_value is not None:
165+
return default_value
166+
else:
167+
raise RuntimeError(f"No arg of type {key}")
168+
else:
169+
raise RuntimeError("Invalid type")
170+
171+
172+
def set_node_arg(node: torch.fx.Node, i: int | str, value):
173+
"""
174+
Help-function for setting a value in node.args/ kwargs. If the index is one larger than the list size, the value is instead appended to the list.
175+
"""
176+
if isinstance(i, int):
177+
if 0 <= i < len(node.args):
178+
args = list(node.args)
179+
args[i] = value
180+
node.args = tuple(args)
181+
return
182+
elif i == len(node.args):
183+
node.args = node.args + (value,)
184+
else:
185+
raise RuntimeError(
186+
f"Out of bounds index {i} for setting value in {node} args (of size {len(node.args)})"
187+
)
188+
elif isinstance(i, str):
189+
kwargs = dict(node.kwargs)
190+
kwargs[i] = value
191+
node.kwargs = kwargs
192+
else:
193+
raise RuntimeError("Invalid type")

backends/arm/_passes/decompose_meandim_pass.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-unsafe
88

99
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1011
from executorch.exir.dialects._ops import ops as exir_ops
1112
from executorch.exir.pass_base import ExportPass
1213

@@ -42,16 +43,16 @@ def call_operator(self, op, args, kwargs, meta):
4243
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
4344
return super().call_operator(op, args, kwargs, meta)
4445

45-
x = args[0]
46-
dim = args[1]
47-
keepdim = args[2] if len(args) > 2 else False
48-
if not keepdim:
49-
return super().call_operator(op, args, kwargs, meta)
50-
# if keepdim == True and dim == [-1, -2], mean.dim can be
46+
x = get_node_arg(args, 0)
47+
dim = get_node_arg(args, 1)
48+
keepdim = get_node_arg(args, 2, False)
49+
50+
# if dim == [-1, -2], mean.dim can be
5151
# decomposed to avg_pool2d. This is handled by ConvertMeanDimToAveragePool.
5252
if dim == [-1, -2]:
5353
# Simply return the mean.dim operator for future decomposition.
5454
return super().call_operator(op, args, kwargs, meta)
55+
5556
shape = meta["val"].size()
5657
dtype = meta["val"].dtype
5758
input_shape = x.data.size()

backends/arm/_passes/decompose_var_pass.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
import torch
11+
from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
1112
from executorch.exir.dialects._ops import ops as exir_ops
1213
from executorch.exir.pass_base import ExportPass
1314

@@ -53,26 +54,30 @@ def call_operator(self, op, args, kwargs, meta):
5354
torch.ops.aten.var.dim,
5455
):
5556
return super().call_operator(op, args, kwargs, meta)
56-
shape = meta["val"].size()
57+
58+
x = args[0]
59+
input_shape = x.data.size()
60+
shape = list(meta["val"].size())
61+
if shape == []:
62+
shape = [1 for _ in input_shape]
63+
5764
dtype = meta["val"].dtype
58-
dim = args[1] if len(args) > 1 else list(range(len(shape)))
65+
# Get dim from args based on argument type
66+
dim = get_node_arg(args, key=list, default_value=list(range(len(shape))))
67+
5968
if op == torch.ops.aten.var.dim:
60-
correction = args[-2]
61-
keepdim = args[-1]
69+
keepdim = get_node_arg(args, bool, False)
70+
correction = get_node_arg(args, int, 1)
6271
else:
63-
correction = kwargs["correction"]
64-
keepdim = kwargs.get("keepdim", False)
65-
if not keepdim:
66-
return super().call_operator(op, args, kwargs, meta)
72+
correction = get_node_arg(kwargs, "correction", 1)
73+
keepdim = get_node_arg(kwargs, "keepdim", False)
6774

68-
x = args[0]
69-
input_shape = x.data.size()
7075
N = 1
7176
for d in dim:
7277
N *= input_shape[d]
7378

7479
mean_op, diff_op, mul_op, sum_op, full_op = get_var_decomposition(op)
75-
mean = super().call_operator(mean_op, (x, dim, keepdim), {}, meta)
80+
mean = super().call_operator(mean_op, (x, dim, True), {}, meta)
7681
diff = super().call_operator(diff_op, (x, mean), {}, meta)
7782
squared_diff = super().call_operator(mul_op, (diff, diff), {}, meta)
7883
sum = super().call_operator(sum_op, (squared_diff, dim, keepdim), {}, meta)

backends/arm/_passes/insert_squeeze_after_sum_pass.py renamed to backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,18 @@
1010

1111
import torch
1212
import torch.fx
13-
from executorch.backends.arm._passes.arm_pass_utils import create_node
13+
from executorch.backends.arm._passes.arm_pass_utils import (
14+
create_node,
15+
get_node_arg,
16+
set_node_arg,
17+
)
1418
from executorch.exir.dialects._ops import ops as exir_ops
1519
from executorch.exir.pass_base import ExportPass, PassResult
1620

1721

18-
class InsertSqueezeAfterSumPass(ExportPass):
22+
class KeepDimsFalseToSqueezePass(ExportPass):
1923
"""
20-
In Pytorch, the default behaviour of Tensor.sum is to squeeze
24+
In Pytorch, the default behaviour of for example Tensor.sum is to squeeze
2125
the dimension that is summed (keep_dim = False).
2226
However, in TOSA, REDUCE_SUM always preserves the
2327
rank of the input (keep_dim = True).
@@ -31,28 +35,52 @@ class InsertSqueezeAfterSumPass(ExportPass):
3135
squeeze(dim = dims)
3236
"""
3337

38+
# CURRENTLY NOT HANDLED OPS
39+
# exir_ops.edge.aten.amax,
40+
# exir_ops.edge.aten.amin,
41+
# exir_ops.edge.aten.any.dim,
42+
# exir_ops.edge.aten.any.dims,
43+
# exir_ops.edge.aten.argmax,
44+
# exir_ops.edge.aten.argmin,
45+
# exir_ops.edge.aten.max.dim,
46+
# exir_ops.edge.aten.min.dim,
47+
# exir_ops.edge.aten.prod.dim_int,
48+
49+
# HANDLED OPS
50+
# exir_ops.edge.aten.sum.dim_IntList
51+
# exir_ops.edge.aten.var.correction (decomposed in decompose_var_pass)
52+
# exir_ops.edge.aten.var.dim (decomposed in decompose_var_pass)
53+
# exir_ops.edge.aten.mean.dim (decomposed in decompose_meandim_pass)
54+
3455
def call(self, graph_module: torch.fx.GraphModule):
3556
for node in graph_module.graph.nodes:
57+
keep_dim_index = None
58+
3659
if node.op != "call_function":
3760
continue
38-
if node.target != exir_ops.edge.aten.sum.dim_IntList:
61+
if node.target == exir_ops.edge.aten.sum.dim_IntList:
62+
keep_dim_index = 2
63+
else:
3964
continue
65+
4066
sum_node = cast(torch.fx.Node, node)
41-
keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
67+
keep_dim = get_node_arg(sum_node.args, keep_dim_index, False)
68+
4269
if keep_dim:
4370
continue
4471

45-
dim_list = cast(list[int], sum_node.args[1])
72+
dim_list = get_node_arg(sum_node.args, 1, [0])
4673

4774
# Add keep_dim = True arg to sum node.
48-
sum_node.args = sum_node.args[0:2] + (True,)
75+
set_node_arg(sum_node, 2, True)
4976

5077
with graph_module.graph.inserting_after(sum_node):
5178
squeeze_node = create_node(
5279
graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
5380
)
5481
sum_node.replace_all_uses_with(squeeze_node)
5582
squeeze_node.args = (sum_node, dim_list)
83+
5684
graph_module.graph.eliminate_dead_code()
5785
graph_module.recompile()
5886
graph_module = super().call(graph_module).graph_module

backends/arm/operator_support/__init__.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,4 @@
55

66
# pyre-unsafe
77

8-
from . import ( # noqa
9-
mean_dim_support,
10-
right_shift_support,
11-
tosa_supported_operators,
12-
var_correction_support,
13-
)
8+
from . import right_shift_support, to_copy_support, tosa_supported_operators # noqa

backends/arm/operator_support/mean_dim_support.py

-33
This file was deleted.

0 commit comments

Comments
 (0)