From da090a7fbcb560d41cdd713bd427905bd075a057 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 3 Feb 2025 14:55:09 +0100 Subject: [PATCH 1/2] Remove unnecessary asserts from op_sigmoid and op_log Signed-off-by: Erik Lundell Change-Id: I35e1de914f650f4e005e81f584033385179906a2 --- backends/arm/operators/op_log.py | 1 - backends/arm/operators/op_sigmoid.py | 1 - 2 files changed, 2 deletions(-) diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 7f664900b31..d8a136e37f8 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -36,7 +36,6 @@ def define_node( output: TosaArg, ) -> None: assert len(node.all_input_nodes) == 1 - assert len(node.users) == 1 assert inputs[0].dtype == output.dtype == ts.DType.FP32 tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 118c813dcf4..01cf2a6ed04 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -37,7 +37,6 @@ def define_node( ) -> None: assert len(node.all_input_nodes) == 1 - assert len(node.users) == 1 assert inputs[0].dtype == output.dtype == ts.DType.FP32 tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name]) From e0a9be449b16122f6e8c217072a04dcb01c780b7 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 28 Jan 2025 08:39:42 +0100 Subject: [PATCH 2/2] Add is_node_supported checks for 4 ops For convolution, maxpool2d, avgpool2d, and sum. The checks mostly target hardware constraints on Ethos-U55, though convolution also checks for some unsupported behavior. 
Signed-off-by: Erik Lundell Change-Id: Ic56119e438a476584c4c4b6007a81781907cc96c --- backends/arm/operator_support/__init__.py | 11 ++- .../operator_support/convolution_support.py | 99 +++++++++++++++++++ .../arm/operator_support/pool_2d_support.py | 85 ++++++++++++++++ .../operator_support/reduce_sum_support.py | 51 ++++++++++ .../tosa_supported_operators.py | 4 - backends/arm/test/ops/test_avg_pool.py | 33 +++++++ backends/arm/test/ops/test_conv2d.py | 55 +++++++++++ backends/arm/test/ops/test_max_pool.py | 49 ++++++--- backends/arm/test/ops/test_sum.py | 23 +++++ 9 files changed, 389 insertions(+), 21 deletions(-) create mode 100644 backends/arm/operator_support/convolution_support.py create mode 100644 backends/arm/operator_support/pool_2d_support.py create mode 100644 backends/arm/operator_support/reduce_sum_support.py diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 08f58b1e437..c6895cce492 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -1,8 +1,15 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe -from . import right_shift_support, to_copy_support, tosa_supported_operators # noqa +from . import ( # noqa + convolution_support, + pool_2d_support, + reduce_sum_support, + right_shift_support, + to_copy_support, + tosa_supported_operators, +) diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py new file mode 100644 index 00000000000..ffa74942fa6 --- /dev/null +++ b/backends/arm/operator_support/convolution_support.py @@ -0,0 +1,99 @@ +# Copyright 2025 Arm Limited and/or its affiliates. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast + +import torch +import torch.fx as fx +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class ConvolutionSupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten.convolution.default] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + + # Not implemented + transposed = cast(bool, node.args[6]) + output_padding = cast(list[int], node.args[7]) + if transposed: + return False + + for pad in output_padding: + if pad != 0: + return False + + # Hardware specific constraints + if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): + return True + else: + return self._is_node_supported_u55(node) + + def _is_node_supported_u55(self, node: fx.Node): + """Hardware constraints for Ethos-U55 case, Vela 4.2.0 (25.02 release)""" + + shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape + shape_out = node.meta["val"].shape + kernel = cast(fx.Node, node.args[1]).meta["val"].shape + group = cast(int, node.args[8]) + + C_in = shape_in[1] + C_out = shape_out[1] + if (C_in == group) and (C_out % C_in) == 0: + # Depthwise convolution + for dim in shape_in[1:]: + if not 1 <= dim <= 65536: + return False + else: + # Convolution + if not 1 <= C_in <= 65536: + return False + + kernel_w = kernel[2] + kernel_h = kernel[3] if len(kernel) > 3 else 1 + # Kernel condition misses constraint on sum of absolute weights + if not 1 <= kernel_h <= 64 or not 1 <= 
kernel_w * kernel_h <= 4096: + return False + + if not self._stride_condition(node): + return False + + return True + + def _stride_condition(self, node: fx.Node) -> bool: + """This condition is somewhat complex but boils down + to not supporting stride > 3, unless we have some special conditions. + This condition is a simplified, relaxed version of the hardware constraint, + since the actual constraint requires information not available + here (without a lot of work). + + This means that we might accept ops that are not actually supported. + """ + strides = cast(list[int], node.args[3]) + has_padding = any(pad > 0 for pad in cast(list[int], node.args[4])) + dilations = cast(list[int], node.args[5]) + if len(dilations) == 1: + dilations = [dilations[0]] * 2 + if len(strides) == 1: + strides = [strides[0]] * 2 + + for stride, dilation in zip(strides, dilations): + stride_condition = 1 <= stride <= 3 + dilation_condition = (not has_padding) and (dilation == 1) + if (not stride_condition) and (not dilation_condition): + return False + + return True diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py new file mode 100644 index 00000000000..ae3c7120731 --- /dev/null +++ b/backends/arm/operator_support/pool_2d_support.py @@ -0,0 +1,85 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import cast + +import torch +import torch.fx as fx +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +def kernel_check(kernel: tuple[int, int]) -> bool: + if not (1 <= kernel[0] * kernel[1] <= 65536): + return False + return 1 <= kernel[1] <= 256 + + +def stride_check(strides: tuple[int, int]) -> bool: + return all(1 <= stride <= 3 for stride in strides) + + +def dim_check(shape=torch.Size) -> bool: + check = shape[0] == 1 + for dim in shape: + check &= 1 <= dim <= 65536 + return check + + +@register_tosa_support_check +class AvgPool2dSupported(SupportedTOSAOperatorCheck): + targets = [ + exir_ops.edge.aten.avg_pool2d.default, + ] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): + return True + + # U55 case, Vela 4.2.0 (25.02 release) + shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape + kernel = cast(tuple[int, int], node.args[1]) + stride = cast(tuple[int, int], node.args[2]) + if len(node.args) > 3: + # Padding case + if not all(1 <= k <= 8 for k in kernel): + return False + else: + if not kernel_check(kernel): + return False + + return dim_check(shape) and stride_check(stride) + + +@register_tosa_support_check +class MaxPool2dSupported(SupportedTOSAOperatorCheck): + targets = [ + exir_ops.edge.aten.max_pool2d_with_indices.default, + ] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + if not 
(isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): + return True + + # U55 case, Vela 4.2.0 (25.02 release) + shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape + kernel = cast(tuple[int, int], node.args[1]) + stride = cast(tuple[int, int], node.args[2]) + + return kernel_check(kernel) and dim_check(shape) and stride_check(stride) diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py new file mode 100644 index 00000000000..1a337be2da1 --- /dev/null +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -0,0 +1,51 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import cast + +import torch.fx as fx +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class SumSupported(SupportedTOSAOperatorCheck): + targets = [exir_ops.edge.aten.sum.dim_IntList] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): + return True + + # U55 case, Vela 4.2.0 (25.02 release) + input_shape = node.all_input_nodes[0].meta["val"].shape + dim_list = cast(list[int], node.args[1]) + dim_list = [dim % len(input_shape) for dim in dim_list] + + for dim in dim_list: + if not 1 <= input_shape[dim] <= 65536: + return False + + # We can't be certain which dim is the last in memory yet, + # so always go for the stricter condition. 
+ pre_R_product = 1.0 + for length in input_shape[:dim]: + pre_R_product *= length + post_R_product = 1.0 + for length in input_shape[dim + 1 :]: + post_R_product *= length + if not 1 <= pre_R_product <= 65536: + return False + if not 1 <= post_R_product <= 65536: + return False + return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 36914579fe4..237da6214e8 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -82,7 +82,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten.hardsigmoid.default, exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.hardswish.default, - exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.exp.default, @@ -97,8 +96,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten.mul.Tensor, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.native_layer_norm.default, - exir_ops.edge.aten.avg_pool2d.default, - exir_ops.edge.aten.max_pool2d_with_indices.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mm.default, @@ -113,7 +110,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.tanh.default, exir_ops.edge.aten.upsample_nearest2d.vec, exir_ops.edge.aten.var.correction, diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py index 16396950dc4..6bea749b8dc 100644 --- a/backends/arm/test/ops/test_avg_pool.py +++ b/backends/arm/test/ops/test_avg_pool.py @@ -172,3 +172,36 @@ def test_avgpool2d_tosa_u85_BI( common.get_u85_compile_spec(), (test_data,), 
) + + reject_data_suite = [ + (AvgPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)), + (AvgPool2d((2, 9), 1, 1), torch.rand(1, 16, 5, 32)), + (AvgPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)), + (AvgPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)), + (AvgPool2d((800, 90), 1, 0), torch.rand(1, 16, 850, 100)), + ] + + @parameterized.expand(reject_data_suite) + def test_reject_avgpool2d_u55_BI( + self, + module: torch.nn.Module, + test_data: torch.tensor, + ): + compile_spec = common.get_u55_compile_spec() + tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec) + quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config()) + + ( + ArmTester( + module, + example_inputs=(test_data,), + compile_spec=compile_spec, + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.avg_pool2d.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge_transform_and_lower() + .check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) + ) diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index 878c65757f2..96464c17097 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -9,6 +9,7 @@ import torch from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineBI, EthosU85PipelineBI, @@ -406,3 +407,57 @@ def test_conv2d_u85_BI_on_fvp(test_module): test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True ) pipeline.run() + + +reject_suite = { + "large_stride": Conv2d( + in_channels=1, + out_channels=1, + kernel_size=(2, 4), + stride=(2, 4), + padding=1, + width=10, + height=14, + batches=1, + ), + "large_kernel_height": Conv2d( + in_channels=1, + out_channels=1, + kernel_size=(2, 65), 
+ stride=(1, 1), + padding=0, + width=70, + height=70, + batches=1, + ), + "large_kernel": Conv2d( + in_channels=1, + out_channels=1, + kernel_size=(70, 60), + stride=(1,), + padding=0, + width=80, + height=80, + batches=1, + ), +} + + +@common.parametrize("module", reject_suite) +def test_reject_conv2d_u55_BI( + module: Conv2d, +): + ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_u55_compile_spec(), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.conv2d.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge_transform_and_lower() + .check(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) + ) diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 71d9feca8bf..3752d6c1b2d 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -216,7 +216,7 @@ def test_maxpool2d_tosa_BI_mult_batches( @parameterized.expand(test_data_suite_mult_batches) @pytest.mark.corstone_fvp @conftest.expectedFailureOnFVP # TODO: MLETORCH-433 - def test_maxpool2d_tosa_u55_BI_mult_batches( + def test_maxpool2d_tosa_u85_BI_mult_batches( self, test_name: str, test_data: torch.Tensor, @@ -224,25 +224,44 @@ def test_maxpool2d_tosa_u55_BI_mult_batches( ): tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( self.MaxPool2d(*model_params), - common.get_u55_compile_spec(), + common.get_u85_compile_spec(), (test_data,), ) if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=(test_data,)) - @parameterized.expand(test_data_suite_mult_batches) - @pytest.mark.corstone_fvp - @conftest.expectedFailureOnFVP # TODO: MLETORCH-433 - def test_maxpool2d_tosa_u85_BI_mult_batches( + reject_data_suite = [ + (MaxPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)), + (MaxPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)), + (MaxPool2d((1, 257), 1, 0), 
torch.rand(1, 16, 5, 300)), + (MaxPool2d((800, 90), 1, 0), torch.rand(1, 16, 850, 100)), + ] + + @parameterized.expand(reject_data_suite) + def test_reject_maxpool2d_u55_BI( self, - test_name: str, - test_data: torch.Tensor, - model_params: int | Tuple[int, int], + module: torch.nn.Module, + test_data: torch.tensor, ): - tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( - self.MaxPool2d(*model_params), - common.get_u85_compile_spec(), - (test_data,), + compile_spec = common.get_u55_compile_spec() + tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec) + quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config()) + + ( + ArmTester( + module, + example_inputs=(test_data,), + compile_spec=compile_spec, + ) + .quantize(Quantize(quantizer, get_symmetric_quantization_config())) + .export() + .check_count({"torch.ops.aten.max_pool2d.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge_transform_and_lower() + .check( + [ + "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default" + ] + ) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) ) - if conftest.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs(qtol=1, inputs=(test_data,)) diff --git a/backends/arm/test/ops/test_sum.py b/backends/arm/test/ops/test_sum.py index 22b41c59d05..5627c55ad9e 100644 --- a/backends/arm/test/ops/test_sum.py +++ b/backends/arm/test/ops/test_sum.py @@ -131,3 +131,26 @@ def test_sum_u85_BI(self, test_data: tuple[exampledata_t]): test_data, common.get_u85_compile_spec(), ) + + reject_inputs = [ + ((torch.rand((65537, 1, 1)), 0, False),), + ((torch.rand((800, 90, 1)), 2, False),), + ((torch.rand((3, 2, 800, 90)), 1, False),), + ] + + @parameterized.expand(reject_inputs) + def test_reject_sum_u55_BI(self, example_inputs): + ( + ArmTester( + TestSum.Sum(), + example_inputs=example_inputs, + compile_spec=common.get_u55_compile_spec(), + ) + .quantize() + .export() + 
.check_count({"torch.ops.aten.sum.dim_IntList": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) + .check(["executorch_exir_dialects_edge__ops_aten_sum_dim_IntList"]) + )