
Commit a939888

Author: Nathanael See
Update on "[ET-VK] fix index error bug in ViewCopyToSqueezeUnsqueezePass"
See T214560872. #8226 added the pass to the partition preprocess pass list, so it now runs on all exports. This uncovered a bug in the squeeze-dims-finding function in the MobileNet test case. Differential Revision: [D69254910](https://our.internmc.facebook.com/intern/diff/D69254910/) [ghstack-poisoned]
2 parents: 24483d2 + ae73d03

36 files changed: +1152 -62 lines

.ci/scripts/gather_benchmark_configs.py

+1 -1

@@ -238,7 +238,7 @@ def set_output(name: str, val: Any) -> None:
     try:
         with open(github_output, "a") as env:
             env.write(f"{name}={val}\n")
-    except PermissionError:
+    except (PermissionError, FileNotFoundError):
         # Fall back to printing in case of permission error in unit tests
         print(f"::set-output name={name}::{val}")
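
The broader except clause matters because, in unit tests, the GITHUB_OUTPUT path typically does not exist, so open() raises FileNotFoundError rather than PermissionError. A minimal, self-contained sketch of that behaviour (the env-var lookup and the example call are assumptions for illustration, not code from the repo):

import os
from typing import Any

def set_output(name: str, val: Any) -> None:
    # Assumed lookup; the real script obtains github_output elsewhere.
    github_output = os.getenv("GITHUB_OUTPUT", "")
    try:
        with open(github_output, "a") as env:
            env.write(f"{name}={val}\n")
    except (PermissionError, FileNotFoundError):
        # Fall back to printing so the value is still visible in logs/tests.
        print(f"::set-output name={name}::{val}")

# With GITHUB_OUTPUT unset, open("") raises FileNotFoundError and the fallback prints:
# ::set-output name=benchmark_configs::{}
set_output("benchmark_configs", "{}")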

backends/arm/_passes/arm_pass_manager.py

+1 -1

@@ -123,6 +123,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
+        self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
         self.add_pass(DecomposeBatchNormPass())
@@ -132,7 +133,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertMeanDimToAveragePoolPass())
         self.add_pass(DecomposeDivPass())
         self.add_pass(DecomposeSoftmaxesPass())
-        self.add_pass(FuseBatchnorm2DPass(exported_program))

         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())

backends/arm/operator_support/__init__.py

+9 -2

@@ -1,8 +1,15 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe

-from . import right_shift_support, to_copy_support, tosa_supported_operators  # noqa
+from . import (  # noqa
+    convolution_support,
+    pool_2d_support,
+    reduce_sum_support,
+    right_shift_support,
+    to_copy_support,
+    tosa_supported_operators,
+)
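
The new modules only take effect because importing them runs their @register_tosa_support_check decorators. A rough sketch of that pattern (the registry list and MyOpSupported class are illustrative stand-ins, not the actual implementation in tosa_supported_operators.py):

# Hypothetical, simplified registry to show the import-time side effect.
_registered_checks = []

def register_tosa_support_check(check_cls):
    _registered_checks.append(check_cls)  # runs when the defining module is imported
    return check_cls

@register_tosa_support_check
class MyOpSupported:  # stand-in for e.g. ConvolutionSupported below
    targets = ["edge.aten.my_op.default"]

    def is_node_supported(self, node, tosa_spec):
        return True
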
backends/arm/operator_support/convolution_support.py (new file)

+99

@@ -0,0 +1,99 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast
+
+import torch
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_tosa_support_check
+class ConvolutionSupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.convolution.default]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+
+        # Not implemented
+        transposed = cast(bool, node.args[6])
+        output_padding = cast(list[int], node.args[7])
+        if transposed:
+            return False
+
+        for pad in output_padding:
+            if pad != 0:
+                return False
+
+        # Hardware specific constraints
+        if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
+            return True
+        else:
+            return self._is_node_supported_u55(node)
+
+    def _is_node_supported_u55(self, node: fx.Node):
+        """Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""
+
+        shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
+        shape_out = node.meta["val"].shape
+        kernel = cast(fx.Node, node.args[1]).meta["val"].shape
+        group = cast(int, node.args[8])
+
+        C_in = shape_in[1]
+        C_out = shape_out[1]
+        if (C_in == group) and (C_out % C_in) == 0:
+            # Depthwise convolution
+            for dim in shape_in[1:]:
+                if not 1 <= dim <= 65536:
+                    return False
+        else:
+            # Convolution
+            if not 1 <= C_in <= 65536:
+                return False
+
+        kernel_w = kernel[2]
+        kernel_h = kernel[3] if len(kernel) > 3 else 1
+        # Kernel condition misses constraint on sum of absolute weights
+        if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096:
+            return False
+
+        if not self._stride_condition(node):
+            return False
+
+        return True
+
+    def _stride_condition(self, node: fx.Node) -> bool:
+        """This condition is somewhat complex but boils down
+        to not supporting stride > 3, unless we have some special conditions.
+        This condition is a simplified, relaxed version of the hardware constraint,
+        since the actual constraint requires information not available
+        here (without a lot of work).
+
+        This means that we might accept ops that are not actually supported.
+        """
+        strides = cast(list[int], node.args[3])
+        has_padding = any(pad > 0 for pad in cast(list[int], node.args[4]))
+        dilations = cast(list[int], node.args[5])
+        if len(dilations) == 1:
+            dilations = [dilations[0]] * 2
+        if len(strides) == 1:
+            strides = [strides[0]] * 2
+
+        for stride, dilation in zip(strides, dilations):
+            stride_condition = 1 <= stride <= 3
+            dilation_condition = (not has_padding) and (dilation == 1)
+            if (not stride_condition) and (not dilation_condition):
+                return False
+
+        return True
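
For context, a small self-contained sketch (not part of the commit) of how the U55 channel and kernel bounds in _is_node_supported_u55 play out on concrete values; the helper name and example shapes are made up:

def u55_conv_shape_ok(shape_in, shape_out, kernel, group):
    # Mirrors the channel/kernel checks above on plain tuples instead of fx.Nodes.
    C_in, C_out = shape_in[1], shape_out[1]
    if C_in == group and C_out % C_in == 0:
        # Depthwise: every non-batch input dim must lie in [1, 65536]
        if not all(1 <= d <= 65536 for d in shape_in[1:]):
            return False
    else:
        # Regular convolution: only the input channel count is bounded
        if not 1 <= C_in <= 65536:
            return False
    kernel_w = kernel[2]
    kernel_h = kernel[3] if len(kernel) > 3 else 1
    return 1 <= kernel_h <= 64 and 1 <= kernel_w * kernel_h <= 4096

# A 3x3 kernel on a 1x16x128x128 input passes; kernel dims (3, 70) fail the 64 bound.
print(u55_conv_shape_ok((1, 16, 128, 128), (1, 32, 128, 128), (32, 16, 3, 3), 1))   # True
print(u55_conv_shape_ok((1, 16, 128, 128), (1, 32, 128, 128), (32, 16, 3, 70), 1))  # False
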
backends/arm/operator_support/pool_2d_support.py (new file)

+85

@@ -0,0 +1,85 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast
+
+import torch
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+def kernel_check(kernel: tuple[int, int]) -> bool:
+    if not (1 <= kernel[0] * kernel[1] <= 65536):
+        return False
+    return 1 <= kernel[1] <= 256
+
+
+def stride_check(strides: tuple[int, int]) -> bool:
+    return all(1 <= stride <= 3 for stride in strides)
+
+
+def dim_check(shape=torch.Size) -> bool:
+    check = shape[0] == 1
+    for dim in shape:
+        check &= 1 <= dim <= 65536
+    return check
+
+
+@register_tosa_support_check
+class AvgPool2dSupported(SupportedTOSAOperatorCheck):
+    targets = [
+        exir_ops.edge.aten.avg_pool2d.default,
+    ]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+        if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
+            return True
+
+        # U55 case, Vela 4.2.0 (25.02 release)
+        shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
+        kernel = cast(tuple[int, int], node.args[1])
+        stride = cast(tuple[int, int], node.args[2])
+        if len(node.args) > 3:
+            # Padding case
+            if not all(1 <= k <= 8 for k in kernel):
+                return False
+        else:
+            if not kernel_check(kernel):
+                return False
+
+        return dim_check(shape) and stride_check(stride)
+
+
+@register_tosa_support_check
+class MaxPool2dSupported(SupportedTOSAOperatorCheck):
+    targets = [
+        exir_ops.edge.aten.max_pool2d_with_indices.default,
+    ]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+        if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
+            return True
+
+        # U55 case, Vela 4.2.0 (25.02 release)
+        shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
+        kernel = cast(tuple[int, int], node.args[1])
+        stride = cast(tuple[int, int], node.args[2])
+
+        return kernel_check(kernel) and dim_check(shape) and stride_check(stride)
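
A quick illustration (not from the commit) of the three helpers above, using values similar to the new reject tests in test_avg_pool.py:

import torch

print(kernel_check((2, 9)))      # True:  area 18 <= 65536 and 9 <= 256
print(kernel_check((1, 257)))    # False: second kernel dim 257 > 256
print(stride_check((1, 4)))      # False: stride 4 > 3
print(dim_check(torch.Size([1, 16, 5, 32])))  # True
print(dim_check(torch.Size([2, 5, 5, 5])))    # False: batch dim must be 1
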
backends/arm/operator_support/reduce_sum_support.py (new file)

+51

@@ -0,0 +1,51 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast
+
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+@register_tosa_support_check
+class SumSupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.sum.dim_IntList]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
+        if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
+            return True
+
+        # U55 case, Vela 4.2.0 (25.02 release)
+        input_shape = node.all_input_nodes[0].meta["val"].shape
+        dim_list = cast(list[int], node.args[1])
+        dim_list = [dim % len(input_shape) for dim in dim_list]
+
+        for dim in dim_list:
+            if not 1 <= input_shape[dim] <= 65536:
+                return False
+
+            # We can't be certain of which dim is the last in memory yet,
+            # Always go for stricter condition.
+            pre_R_product = 1.0
+            for length in input_shape[:dim]:
+                pre_R_product *= length
+            post_R_product = 1.0
+            for length in input_shape[dim + 1 :]:
+                post_R_product *= length
+            if not 1 <= pre_R_product <= 65536:
+                return False
+            if not 1 <= post_R_product <= 65536:
+                return False
+        return True
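
Worked example (not part of the commit) of the pre/post reduction products for a concrete shape, showing how the 65536 bound can reject a sum:

input_shape = (1, 8, 300, 300)
dim = 1  # reduce over the channel dim
pre_R_product = 1
for length in input_shape[:dim]:
    pre_R_product *= length      # 1
post_R_product = 1
for length in input_shape[dim + 1:]:
    post_R_product *= length     # 300 * 300 = 90000
print(pre_R_product, post_R_product)   # 1 90000
print(1 <= post_R_product <= 65536)    # False -> this sum would be rejected for U55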

backends/arm/operator_support/tosa_supported_operators.py

-4

@@ -82,7 +82,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.hardsigmoid.default,
             exir_ops.edge.aten.hardtanh.default,
             exir_ops.edge.aten.hardswish.default,
-            exir_ops.edge.aten.convolution.default,
             exir_ops.edge.aten.div.Tensor,
             exir_ops.edge.aten.eq.Tensor,
             exir_ops.edge.aten.exp.default,
@@ -97,8 +96,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.native_layer_norm.default,
-            exir_ops.edge.aten.avg_pool2d.default,
-            exir_ops.edge.aten.max_pool2d_with_indices.default,
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mean.dim,
             exir_ops.edge.aten.mm.default,
@@ -113,7 +110,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten._log_softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
-            exir_ops.edge.aten.sum.dim_IntList,
             exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.upsample_nearest2d.vec,
             exir_ops.edge.aten.var.correction,

backends/arm/operators/op_log.py

-1

@@ -36,7 +36,6 @@ def define_node(
         output: TosaArg,
     ) -> None:
         assert len(node.all_input_nodes) == 1
-        assert len(node.users) == 1
         assert inputs[0].dtype == output.dtype == ts.DType.FP32

         tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name])

backends/arm/operators/op_sigmoid.py

-1

@@ -37,7 +37,6 @@ def define_node(
     ) -> None:

         assert len(node.all_input_nodes) == 1
-        assert len(node.users) == 1
         assert inputs[0].dtype == output.dtype == ts.DType.FP32

         tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])

backends/arm/test/ops/test_avg_pool.py

+33

@@ -172,3 +172,36 @@ def test_avgpool2d_tosa_u85_BI(
             common.get_u85_compile_spec(),
             (test_data,),
         )
+
+    reject_data_suite = [
+        (AvgPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)),
+        (AvgPool2d((2, 9), 1, 1), torch.rand(1, 16, 5, 32)),
+        (AvgPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)),
+        (AvgPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)),
+        (AvgPool2d((800, 90), 1, 0), torch.rand(1, 16, 850, 100)),
+    ]
+
+    @parameterized.expand(reject_data_suite)
+    def test_reject_avgpool2d_u55_BI(
+        self,
+        module: torch.nn.Module,
+        test_data: torch.tensor,
+    ):
+        compile_spec = common.get_u55_compile_spec()
+        tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
+        quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())
+
+        (
+            ArmTester(
+                module,
+                example_inputs=(test_data,),
+                compile_spec=compile_spec,
+            )
+            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+            .export()
+            .check_count({"torch.ops.aten.avg_pool2d.default": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge_transform_and_lower()
+            .check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
+        )
