diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 170ea5260d8..3e38a7383ed 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -43,12 +43,19 @@ def _build_generic_avgpool2d( output_zp: int, accumulator_type: ts.DType, ) -> None: - input_tensor = inputs[0] + input_tensor = inputs[0] kernel_size_list = inputs[1].special stride_size_list = inputs[2].special + try: pad_size_list = inputs[3].special + pad_size_list = [ + pad_size_list[0], + pad_size_list[0], + pad_size_list[1], + pad_size_list[1], + ] except IndexError: pad_size_list = [0, 0, 0, 0] diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index f32300f561d..5305f95880c 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -42,9 +42,15 @@ def define_node( stride = inputs[2].special try: - padding = [*inputs[3].special, *inputs[3].special] + pad_size_list = inputs[3].special + pad_size_list = [ + pad_size_list[0], + pad_size_list[0], + pad_size_list[1], + pad_size_list[1], + ] except IndexError: - padding = [0, 0, 0, 0] + pad_size_list = [0, 0, 0, 0] accumulator_type = output.dtype @@ -63,7 +69,7 @@ def define_node( attr.PoolAttribute( kernel=kernel_size, stride=stride, - pad=padding, + pad=pad_size_list, input_zp=input_zp, output_zp=output_zp, accum_dtype=accumulator_type, diff --git a/backends/arm/test/ops/test_avg_pool.py b/backends/arm/test/ops/test_avg_pool.py deleted file mode 100644 index fa4662e54f0..00000000000 --- a/backends/arm/test/ops/test_avg_pool.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import unittest - -from typing import Tuple - -import pytest - -import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - EthosUQuantizer, - get_symmetric_quantization_config, - TOSAQuantizer, -) -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.xnnpack.test.tester.tester import Quantize -from executorch.exir.backend.backend_details import CompileSpec -from parameterized import parameterized - - -test_data_suite = [ - # (test_name, test_data, [kernel_size, stride, padding]) - ("zeros", torch.zeros(1, 16, 50, 32), [4, 2, 0]), - ("ones", torch.zeros(1, 16, 50, 32), [4, 2, 0]), - ("rand", torch.rand(1, 16, 50, 32), [4, 2, 0]), - ("randn", torch.randn(1, 16, 50, 32), [4, 2, 0]), -] - - -class TestAvgPool2d(unittest.TestCase): - """Tests AvgPool2d.""" - - class AvgPool2d(torch.nn.Module): - def __init__( - self, - kernel_size: int | Tuple[int, int], - stride: int | Tuple[int, int], - padding: int | Tuple[int, int], - ): - super().__init__() - self.avg_pool_2d = torch.nn.AvgPool2d( - kernel_size=kernel_size, stride=stride, padding=padding - ) - - def forward(self, x): - return self.avg_pool_2d(x) - - def _test_avgpool2d_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .check(["torch.ops.aten.avg_pool2d.default"]) - .check_not(["torch.ops.quantized_decomposed"]) - .to_edge() - .partition() - .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_avgpool2d_tosa_BI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.tensor] - ): - tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI") - compile_spec = common.get_tosa_compile_spec(tosa_spec) - quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config()) - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) - .export() - .check_count({"torch.ops.aten.avg_pool2d.default": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .partition() - .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) - - def _test_avgpool2d_tosa_ethos_BI_pipeline( - self, - module: torch.nn.Module, - compile_spec: CompileSpec, - test_data: Tuple[torch.tensor], - ): - quantizer = EthosUQuantizer(compile_spec).set_io( - get_symmetric_quantization_config() - ) - tester = ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) - .export() - .check_count({"torch.ops.aten.avg_pool2d.default": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .partition() - .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .serialize() - ) - if conftest.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - - @parameterized.expand(test_data_suite) - def test_avgpool2d_tosa_MI( - self, - test_name: str, - test_data: torch.Tensor, - model_params: int | Tuple[int, int], - ): - self._test_avgpool2d_tosa_MI_pipeline( - self.AvgPool2d(*model_params), (test_data,) - ) - - @parameterized.expand(test_data_suite) - def test_avgpool2d_tosa_BI( - self, - test_name: str, - test_data: torch.Tensor, - model_params: int | Tuple[int, int], - ): - self._test_avgpool2d_tosa_BI_pipeline( - self.AvgPool2d(*model_params), (test_data,) - ) - - @parameterized.expand(test_data_suite) - @pytest.mark.corstone_fvp - def test_avgpool2d_tosa_u55_BI( - self, - test_name: str, - test_data: torch.Tensor, - model_params: int | Tuple[int, int], - ): - self._test_avgpool2d_tosa_ethos_BI_pipeline( - self.AvgPool2d(*model_params), - common.get_u55_compile_spec(), - (test_data,), - ) - - @parameterized.expand(test_data_suite) - @pytest.mark.corstone_fvp - def test_avgpool2d_tosa_u85_BI( - self, - test_name: str, - test_data: torch.Tensor, - model_params: int | Tuple[int, int], - ): - self._test_avgpool2d_tosa_ethos_BI_pipeline( - self.AvgPool2d(*model_params), - common.get_u85_compile_spec(), - (test_data,), - ) - - reject_data_suite = [ - (AvgPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)), - (AvgPool2d((2, 9), 1, 1), torch.rand(1, 16, 5, 32)), - (AvgPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)), - (AvgPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)), - (AvgPool2d((800, 90), 1, 0), torch.rand(1, 16, 850, 100)), - ] - - @parameterized.expand(reject_data_suite) - def test_reject_avgpool2d_u55_BI( - self, - module: torch.nn.Module, - test_data: torch.tensor, - ): - compile_spec = common.get_u55_compile_spec() - quantizer = EthosUQuantizer(compile_spec).set_io( - get_symmetric_quantization_config() - ) - - ( - ArmTester( - module, - example_inputs=(test_data,), - compile_spec=compile_spec, - ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) - .export() - .check_count({"torch.ops.aten.avg_pool2d.default": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge_transform_and_lower() - .check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 0}) - ) diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py new file mode 100644 index 00000000000..2a50ef38834 --- /dev/null +++ b/backends/arm/test/ops/test_avg_pool2d.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + OpNotSupportedPipeline, + TosaPipelineBI, + TosaPipelineMI, +) + + +aten_op = "torch.ops.aten.avg_pool2d.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default" +input_t = Tuple[torch.Tensor] + + +class AvgPool2d(torch.nn.Module): + def __init__( + self, + kernel_size: int | Tuple[int, int], + stride: int | Tuple[int, int], + padding: int | Tuple[int, int], + ): + super().__init__() + self.avg_pool_2d = torch.nn.AvgPool2d( + kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x): + return self.avg_pool_2d(x) + + +test_modules = { + "zeros": (AvgPool2d(4, 2, 0), (torch.zeros(1, 16, 50, 32),)), + "ones": (AvgPool2d(4, 2, 0), (torch.ones(1, 16, 50, 32),)), + "rand": (AvgPool2d(4, 2, 0), (torch.rand(1, 16, 50, 32),)), + "randn": (AvgPool2d(4, 2, 0), (torch.randn(1, 16, 50, 32),)), + "kernel_3x3_stride_1_pad_1": ( + AvgPool2d((3, 3), (1, 1), 1), + (torch.rand(1, 16, 50, 32),), + ), + "kernel_3x2_stride_1x2_pad_1x0": ( + AvgPool2d((3, 2), (1, 2), (1, 0)), + (torch.rand(1, 16, 50, 32),), + ), + "kernel_4x6_stride_1x2_pad_2x3": ( + AvgPool2d((4, 6), (1, 2), (2, 3)), + (torch.rand(1, 16, 50, 32),), + ), +} + + +@common.parametrize("test_module", test_modules) +def test_avgpool2d_tosa_MI(test_module): + model, input_tensor = test_module + + pipeline = TosaPipelineMI[input_t](model, input_tensor, aten_op, exir_op) + pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_avgpool2d_tosa_BI(test_module): + model, input_tensor = test_module + + pipeline = TosaPipelineBI[input_t]( + model, + input_tensor, + aten_op, + exir_op, + symmetric_io_quantization=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_avgpool2d_u55_BI(test_module): + model, input_tensor = test_module + + pipeline = EthosU55PipelineBI[input_t]( + model, + input_tensor, + aten_op, + exir_op, + run_on_fvp=False, + symmetric_io_quantization=True, + ) + + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_avgpool2d_u85_BI(test_module): + model, input_tensor = test_module + + pipeline = EthosU85PipelineBI[input_t]( + model, + input_tensor, + aten_op, + exir_op, + run_on_fvp=False, + symmetric_io_quantization=True, + ) + + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.SkipIfNoCorstone300 +def test_avgpool2d_u55_BI_on_fvp(test_module): + model, input_tensor = test_module + + pipeline = EthosU55PipelineBI[input_t]( + model, + input_tensor, + aten_op, + exir_op, + run_on_fvp=True, + symmetric_io_quantization=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.SkipIfNoCorstone320 +def test_avgpool2d_u85_BI_on_fvp(test_module): + model, input_tensor = test_module + + pipeline = EthosU85PipelineBI[input_t]( + model, + input_tensor, + aten_op, + exir_op, + run_on_fvp=True, + symmetric_io_quantization=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1) + + pipeline.run() + + +reject_modules = { + "kernel_1x1_stride_1_pad_0": (AvgPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)), + "kernel_2x9_stride_1_pad_1": (AvgPool2d((2, 9), 1, 1), torch.rand(1, 16, 5, 32)), + "kernel_1x4_stride_0_pad_0": (AvgPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)), + "kernel_1x257_stride_1_pad_0_large": ( + AvgPool2d((1, 257), 1, 0), + torch.rand(1, 16, 5, 300), + ), + "kernel_800x90_stride_1_pad_0_extreme": ( + AvgPool2d((800, 90), 1, 0), + torch.rand(1, 16, 850, 100), + ), +} + + +@common.parametrize("reject_module", reject_modules) +def test_reject_avgpool2d(reject_module): + + model, test_data = reject_module + + pipeline = OpNotSupportedPipeline[input_t]( + module=model, + test_data=(test_data,), + tosa_version="TOSA-0.80+BI", + non_delegated_ops={}, + n_expected_delegates=0, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_clamp.py b/backends/arm/test/ops/test_clamp.py index f379732343e..368f7967433 100644 --- a/backends/arm/test/ops/test_clamp.py +++ b/backends/arm/test/ops/test_clamp.py @@ -1,167 +1,159 @@ # Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest from numbers import Number from typing import Tuple, Union -import pytest import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - EthosUQuantizer, - get_symmetric_quantization_config, - TOSAQuantizer, +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, ) -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.backends.xnnpack.test.tester.tester import Quantize -from executorch.exir.backend.compile_spec_schema import CompileSpec -from parameterized import parameterized - - -test_data_suite = [ - # (test_name, test_data, min, max) - ("rank_1", torch.rand(10) * 2, -1.0, 1.0), - ("rank_2", torch.rand(1, 35), 0.5, 0.8), - ("rank_3", torch.ones(1, 10, 10), -1, -1), - ("rank_4", torch.rand(1, 10, 10, 1) * 2, -0.1, 2.0), - ("rank_4_mixed_min_max_dtype", torch.rand(1, 10, 10, 5) + 10, 8.0, 10), - ("rank_4_no_min", torch.rand(1, 10, 10, 1) * 10, None, 5), - ("rank_4_no_max", torch.rand(1, 10, 10, 1) - 3, -3.3, None), -] - - -class TestClamp(unittest.TestCase): - """Tests Clamp Operator.""" - - class Clamp(torch.nn.Module): - def __init__( - self, - min: Union[torch.Tensor, Number, None], - max: Union[torch.Tensor, Number, None], - ): - super().__init__() - - self.clamp_min = min - self.clamp_max = max - - def forward(self, x): - return torch.clamp(x, self.clamp_min, self.clamp_max) - - def _test_clamp_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .check(["torch.ops.aten.clamp.default"]) - .check_not(["torch.ops.quantized_decomposed"]) - .to_edge_transform_and_lower() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_clamp_tosa_BI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI") - compile_spec = common.get_tosa_compile_spec(tosa_spec) - quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config()) - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) - .export() - .check_count({"torch.ops.aten.clamp.default": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge_transform_and_lower() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_clamp_ethos_pipeline( - self, - compile_spec: list[CompileSpec], - module: torch.nn.Module, - test_data: Tuple[torch.tensor], - ): - quantizer = EthosUQuantizer(compile_spec).set_io( - get_symmetric_quantization_config() - ) - tester = ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) - .export() - .check_count({"torch.ops.aten.clamp.default": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge_transform_and_lower() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .serialize() - ) - if conftest.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - - @parameterized.expand(test_data_suite) - def test_clamp_tosa_MI( - self, - test_name: str, - test_data: torch.Tensor, - min: Union[torch.Tensor, Number, None], - max: Union[torch.Tensor, Number, None], - ): - self._test_clamp_tosa_MI_pipeline(self.Clamp(min, max), (test_data,)) - @parameterized.expand(test_data_suite) - def test_clamp_tosa_BI( - self, - test_name: str, - test_data: torch.Tensor, - min: Union[torch.Tensor, Number, None], - max: Union[torch.Tensor, Number, None], - ): - self._test_clamp_tosa_BI_pipeline(self.Clamp(min, max), (test_data,)) - @parameterized.expand(test_data_suite) - @pytest.mark.corstone_fvp - def test_clamp_tosa_u55_BI( - self, - test_name: str, - test_data: torch.Tensor, - min: Union[torch.Tensor, Number, None], - max: Union[torch.Tensor, Number, None], - ): - self._test_clamp_ethos_pipeline( - common.get_u55_compile_spec(), self.Clamp(min, max), (test_data,) - ) +aten_op = "torch.ops.aten.clamp.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_clamp_default" +input_t = Tuple[torch.Tensor] + +test_data_suite = { + # test_name: (test_data, min, max) + "rank_1": (torch.rand(10) * 2, -1.0, 1.0), + "rank_2": (torch.rand(1, 35), 0.5, 0.8), + "rank_3": (torch.ones(1, 10, 10), -1, -1), + "rank_4": (torch.rand(1, 10, 10, 1) * 2, -0.1, 2.0), + "rank_4_mixed_min_max_dtype": (torch.rand(1, 10, 10, 5) + 10, 8.0, 10), + "rank_4_no_min": (torch.rand(1, 10, 10, 1) * 10, None, 5), + "rank_4_no_max": (torch.rand(1, 10, 10, 1) - 3, -3.3, None), +} - @parameterized.expand(test_data_suite) - @pytest.mark.corstone_fvp - def test_clamp_tosa_u85_BI( + +class Clamp(torch.nn.Module): + def __init__( self, - test_name: str, - test_data: torch.Tensor, - min: Union[torch.Tensor, Number, None], - max: Union[torch.Tensor, Number, None], + clamp_min: Union[torch.Tensor, Number, None], + clamp_max: Union[torch.Tensor, Number, None], ): - self._test_clamp_ethos_pipeline( - common.get_u85_compile_spec(), self.Clamp(min, max), (test_data,) - ) + super().__init__() + + self.clamp_min = clamp_min + self.clamp_max = clamp_max + + def forward(self, x): + return torch.clamp(x, self.clamp_min, self.clamp_max) + + +@common.parametrize("test_data", test_data_suite) +def test_clamp_tosa_MI(test_data): + + input_tensor, min_val, max_val = test_data + model = Clamp(min_val, max_val) + + pipeline = TosaPipelineMI[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_clamp_tosa_BI(test_data): + + input_tensor, min_val, max_val = test_data + model = Clamp(min_val, max_val) + + pipeline = TosaPipelineBI[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + symmetric_io_quantization=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_clamp_u55_BI(test_data): + + input_tensor, min_val, max_val = test_data + model = Clamp(min_val, max_val) + + pipeline = EthosU55PipelineBI[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + run_on_fvp=False, + symmetric_io_quantization=True, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_clamp_u85_BI(test_data): + + input_tensor, min_val, max_val = test_data + model = Clamp(min_val, max_val) + + pipeline = EthosU85PipelineBI[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + run_on_fvp=False, + symmetric_io_quantization=True, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoCorstone300 +def test_clamp_u55_BI_on_fvp(test_data): + + input_tensor, min_val, max_val = test_data + model = Clamp(min_val, max_val) + + pipeline = EthosU55PipelineBI[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + run_on_fvp=True, + symmetric_io_quantization=True, + ) + + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoCorstone320 +def test_clamp_u85_BI_on_fvp(test_data): + + input_tensor, min_val, max_val = test_data + model = Clamp(min_val, max_val) + + pipeline = EthosU85PipelineBI[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + run_on_fvp=True, + symmetric_io_quantization=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + + pipeline.run() diff --git a/backends/arm/test/ops/test_clone.py b/backends/arm/test/ops/test_clone.py index 543ae9ac40f..2aad62ece24 100644 --- a/backends/arm/test/ops/test_clone.py +++ b/backends/arm/test/ops/test_clone.py @@ -1,5 +1,4 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,78 +7,135 @@ # Tests the clone op which copies the data of the input tensor (possibly with new data format) # -import unittest from typing import Tuple +import pytest import torch -from executorch.backends.arm.quantizer.arm_quantizer import ( - get_symmetric_quantization_config, - TOSAQuantizer, -) from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.backends.arm.tosa_specification import TosaSpecification - -from executorch.backends.xnnpack.test.tester.tester import Quantize - -from parameterized import parameterized - - -class TestSimpleClone(unittest.TestCase): - """Tests clone.""" - - class Clone(torch.nn.Module): - sizes = [10, 15, 50, 100] - test_parameters = [(torch.ones(n),) for n in sizes] - - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor): - x = x.clone() - return x - - def _test_clone_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: torch.Tensor - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .check_count({"torch.ops.aten.clone.default": 1}) - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_clone_tosa_BI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI") - compile_spec = common.get_tosa_compile_spec(tosa_spec) - quantizer = TOSAQuantizer(tosa_spec).set_io(get_symmetric_quantization_config()) - ( - ArmTester(module, example_inputs=test_data, compile_spec=compile_spec) - .quantize(Quantize(quantizer, get_symmetric_quantization_config())) - .export() - .check_count({"torch.ops.aten.clone.default": 1}) - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) - - @parameterized.expand(Clone.test_parameters) - def test_clone_tosa_MI(self, test_tensor: torch.Tensor): - self._test_clone_tosa_MI_pipeline(self.Clone(), (test_tensor,)) - - @parameterized.expand(Clone.test_parameters) - def test_clone_tosa_BI(self, test_tensor: torch.Tensor): - self._test_clone_tosa_BI_pipeline(self.Clone(), (test_tensor,)) + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + + +aten_op = "torch.ops.aten.clone.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_clone_default" + +input_t = Tuple[torch.Tensor] + + +class Clone(torch.nn.Module): + """A simple module that clones an input tensor.""" + + def forward(self, x: torch.Tensor): + return x.clone() + + +test_data_suite = { + "ones_1D_10": (torch.ones(10),), + "ones_1D_50": (torch.ones(50),), + "rand_1D_20": (torch.rand(20),), + "rand_2D_10x10": (torch.rand(10, 10),), + "rand_3D_5x5x5": (torch.rand(5, 5, 5),), + "rand_4D_2x3x4x5": (torch.rand(2, 3, 4, 5),), + "large_tensor": (torch.rand(1000),), +} + + +@common.parametrize("test_data", test_data_suite) +def test_clone_tosa_MI(test_data: Tuple[torch.Tensor]): + + pipeline = TosaPipelineMI[input_t]( + Clone(), + test_data, + aten_op, + exir_op, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_clone_tosa_BI(test_data): + pipeline = TosaPipelineBI[input_t]( + Clone(), + test_data, + aten_op, + exir_op, + symmetric_io_quantization=True, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@pytest.mark.xfail( + reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477" +) +def test_clone_u55_BI(test_data): + pipeline = EthosU55PipelineBI[input_t]( + Clone(), + test_data, + aten_op, + exir_op, + run_on_fvp=False, + symmetric_io_quantization=True, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@pytest.mark.xfail( + reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477" +) +def test_clone_u85_BI(test_data): + pipeline = EthosU85PipelineBI[input_t]( + Clone(), + test_data, + aten_op, + exir_op, + run_on_fvp=False, + symmetric_io_quantization=True, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@pytest.mark.xfail( + reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477" +) +@common.SkipIfNoCorstone300 +def test_clone_u55_BI_on_fvp(test_data): + pipeline = EthosU55PipelineBI[input_t]( + Clone(), + test_data, + aten_op, + exir_op, + run_on_fvp=True, + symmetric_io_quantization=True, + ) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@pytest.mark.xfail( + reason="Empty subgraph leads to Vela compilation failure. See: https://jira.arm.com/browse/MLBEDSW-10477" +) +@common.SkipIfNoCorstone320 +def test_clone_u85_BI_on_fvp(test_data): + pipeline = EthosU85PipelineBI[input_t]( + Clone(), + test_data, + aten_op, + exir_op, + run_on_fvp=True, + symmetric_io_quantization=True, + ) + + pipeline.run() diff --git a/backends/arm/test/ops/test_conv1d.py b/backends/arm/test/ops/test_conv1d.py index 92da09a5ef3..a1ba23ac73a 100644 --- a/backends/arm/test/ops/test_conv1d.py +++ b/backends/arm/test/ops/test_conv1d.py @@ -1,20 +1,24 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest from typing import List, Tuple, Union -import pytest - import torch -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.exir.backend.backend_details import CompileSpec -from parameterized import parameterized +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.conv1d.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default" + +input_t = Tuple[torch.Tensor] class Conv1d(torch.nn.Module): @@ -245,107 +249,93 @@ def forward(self, x): batches=1, ) -# Shenanigan to get a nicer output when test fails. With unittest it looks like: -# FAIL: test_conv1d_tosa_BI_2_3x3_1x3x12x12_st2_pd1 -testsuite = [ - ("2_3x2x40_nobias", conv1d_2_3x2x40_nobias), - ("3_1x3x256_st1", conv1d_3_1x3x256_st1), - ("3_1x3x12_st2_pd1", conv1d_3_1x3x12_st2_pd1), - ("1_1x2x128_st1", conv1d_1_1x2x128_st1), - ("2_1x2x14_st2", conv1d_2_1x2x14_st2), - ("5_3x2x128_st1", conv1d_5_3x2x128_st1), - ("3_1x3x224_st2_pd1", conv1d_3_1x3x224_st2_pd1), - ("7_1x3x16_st2_pd1_dl2_needs_adjust_pass", conv1d_7_1x3x16_st2_pd1_dl2), - ("7_1x3x15_st1_pd0_dl1_needs_adjust_pass", conv1d_7_1x3x15_st1_pd0_dl1), - ("5_1x3x14_st5_pd0_dl1_needs_adjust_pass", conv1d_5_1x3x14_st5_pd0_dl1), - ("5_1x3x9_st5_pd0_dl1_needs_adjust_pass", conv1d_5_1x3x9_st5_pd0_dl1), - ("two_conv1d_nobias", two_conv1d_nobias), - ("two_conv1d", two_conv1d), -] - - -class TestConv1D(unittest.TestCase): - def _test_conv1d_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80+MI", - ), - ) - .export() - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_conv1d_tosa_BI_pipeline( - self, - module: torch.nn.Module, - test_data: Tuple[torch.Tensor], - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec( - "TOSA-0.80+BI", - ), - ) - .quantize() - .export() - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) - - def _test_conv1d_ethosu_BI_pipeline( - self, - module: torch.nn.Module, - compile_spec: CompileSpec, - test_data: Tuple[torch.Tensor], - ): - tester = ( - ArmTester(module, example_inputs=test_data, compile_spec=compile_spec) - .quantize() - .export() - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) - .to_executorch() - .serialize() - ) - if conftest.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - - @parameterized.expand(testsuite) - def test_conv1d_tosa_MI(self, test_name, model): - self._test_conv1d_tosa_MI_pipeline(model, model.get_inputs()) - - @parameterized.expand(testsuite) - def test_conv1d_tosa_BI(self, test_name, model): - self._test_conv1d_tosa_BI_pipeline(model, model.get_inputs()) - - @parameterized.expand(testsuite) - @pytest.mark.corstone_fvp - def test_conv1d_u55_BI(self, test_name, model): - self._test_conv1d_ethosu_BI_pipeline( - model, common.get_u55_compile_spec(), model.get_inputs() - ) - - @parameterized.expand(testsuite) - @pytest.mark.corstone_fvp - def test_conv1d_u85_BI(self, test_name, model): - self._test_conv1d_ethosu_BI_pipeline( - model, common.get_u85_compile_spec(), model.get_inputs() - ) +test_modules = { + "2_3x2x40_nobias": conv1d_2_3x2x40_nobias, + "3_1x3x256_st1": conv1d_3_1x3x256_st1, + "3_1x3x12_st2_pd1": conv1d_3_1x3x12_st2_pd1, + "1_1x2x128_st1": conv1d_1_1x2x128_st1, + "2_1x2x14_st2": conv1d_2_1x2x14_st2, + "5_3x2x128_st1": conv1d_5_3x2x128_st1, + "3_1x3x224_st2_pd1": conv1d_3_1x3x224_st2_pd1, + "7_1x3x16_st2_pd1_dl2_needs_adjust_pass": conv1d_7_1x3x16_st2_pd1_dl2, + "7_1x3x15_st1_pd0_dl1_needs_adjust_pass": conv1d_7_1x3x15_st1_pd0_dl1, + "5_1x3x14_st5_pd0_dl1_needs_adjust_pass": conv1d_5_1x3x14_st5_pd0_dl1, + "5_1x3x9_st5_pd0_dl1_needs_adjust_pass": conv1d_5_1x3x9_st5_pd0_dl1, + "two_conv1d_nobias": two_conv1d_nobias, + "two_conv1d": two_conv1d, +} + + +@common.parametrize("test_module", test_modules) +def test_convolution_1d_tosa_MI(test_module): + pipeline = TosaPipelineMI[input_t]( + test_module, + test_module.get_inputs(), + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_convolution_1d_tosa_BI(test_module): + pipeline = TosaPipelineBI[input_t]( + test_module, + test_module.get_inputs(), + aten_op, + exir_op, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_convolution_1d_u55_BI(test_module): + pipeline = EthosU55PipelineBI[input_t]( + test_module, + test_module.get_inputs(), + aten_op, + exir_op, + run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_convolution_1d_u85_BI(test_module): + pipeline = EthosU85PipelineBI[input_t]( + test_module, + test_module.get_inputs(), + aten_op, + exir_op, + run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.SkipIfNoCorstone300 +def test_convolution_1d_u55_BI_on_fvp(test_module): + pipeline = EthosU55PipelineBI[input_t]( + test_module, + test_module.get_inputs(), + aten_op, + exir_op, + run_on_fvp=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.SkipIfNoCorstone320 +def test_convolution_1d_u85_BI_on_fvp(test_module): + pipeline = EthosU85PipelineBI[input_t]( + test_module, + test_module.get_inputs(), + aten_op, + exir_op, + run_on_fvp=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() diff --git a/backends/arm/test/ops/test_conv2d.py b/backends/arm/test/ops/test_conv2d.py index b41738b3e8d..8083b2ecf71 100644 --- a/backends/arm/test/ops/test_conv2d.py +++ b/backends/arm/test/ops/test_conv2d.py @@ -328,7 +328,7 @@ def forward(self, x): ) # Shenanigan to get a nicer output when test fails. With unittest it looks like: -# FAIL: test_conv2d_tosa_BI_2_3x3_1x3x12x12_st2_pd1 +# FAIL: test_convolution_2d_tosa_BI_2_3x3_1x3x12x12_st2_pd1 test_modules = { "2x2_3x2x40x40_nobias": conv2d_2x2_3x2x40x40_nobias, "3x3_1x3x256x256_st1": conv2d_3x3_1x3x256x256_st1, @@ -358,7 +358,7 @@ def forward(self, x): @common.parametrize("test_module", test_modules) -def test_conv2d_tosa_MI(test_module): +def test_convolution_2d_tosa_MI(test_module): pipeline = TosaPipelineMI[input_t]( test_module, test_module.get_inputs(), aten_op, exir_op ) @@ -366,7 +366,7 @@ def test_conv2d_tosa_MI(test_module): @common.parametrize("test_module", test_modules) -def test_conv2d_tosa_BI(test_module): +def test_convolution_2d_tosa_BI(test_module): pipeline = TosaPipelineBI[input_t]( test_module, test_module.get_inputs(), aten_op, exir_op ) @@ -375,7 +375,7 @@ def test_conv2d_tosa_BI(test_module): @common.parametrize("test_module", test_modules) -def test_conv2d_u55_BI(test_module): +def test_convolution_2d_u55_BI(test_module): pipeline = EthosU55PipelineBI[input_t]( test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=False ) @@ -383,7 +383,7 @@ def test_conv2d_u55_BI(test_module): @common.parametrize("test_module", test_modules) -def test_conv2d_u85_BI(test_module): +def test_convolution_2d_u85_BI(test_module): pipeline = EthosU85PipelineBI[input_t]( test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=False ) @@ -392,7 +392,7 @@ def test_conv2d_u85_BI(test_module): @common.parametrize("test_module", test_modules, fvp_xfails) @common.SkipIfNoCorstone300 -def test_conv2d_u55_BI_on_fvp(test_module): +def test_convolution_2d_u55_BI_on_fvp(test_module): pipeline = EthosU55PipelineBI[input_t]( test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True ) @@ -401,7 +401,7 @@ def test_conv2d_u55_BI_on_fvp(test_module): @common.parametrize("test_module", test_modules, fvp_xfails) @common.SkipIfNoCorstone320 -def test_conv2d_u85_BI_on_fvp(test_module): +def test_convolution_2d_u85_BI_on_fvp(test_module): pipeline = EthosU85PipelineBI[input_t]( test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True ) @@ -443,7 +443,7 @@ def test_conv2d_u85_BI_on_fvp(test_module): @common.parametrize("module", reject_suite) -def test_reject_conv2d_u55_BI( +def test_reject_convolution_2d_u55_BI( module: Conv2d, ): ( diff --git a/backends/arm/test/ops/test_sub.py b/backends/arm/test/ops/test_sub.py index 0812f8a47a1..d1849e830c9 100644 --- a/backends/arm/test/ops/test_sub.py +++ b/backends/arm/test/ops/test_sub.py @@ -1,155 +1,219 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import unittest from typing import Tuple -import pytest - import torch -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from executorch.exir.backend.compile_spec_schema import CompileSpec -from parameterized import parameterized - - -class TestSub(unittest.TestCase): - class Sub(torch.nn.Module): - test_parameters = [ - (torch.ones(5),), - (3 * torch.ones(8),), - (10 * torch.randn(8),), - ] - - def forward(self, x): - return x - x - - class Sub2(torch.nn.Module): - test_parameters = [ - (torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), - ] - - def forward(self, x, y): - return x - y - - def _test_sub_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .check_count({"torch.ops.aten.sub.Tensor": 1}) - .check_not(["torch.ops.quantized_decomposed"]) - .to_edge() - .partition() - .check_not(["torch.ops.aten.sub.Tensor"]) - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - - def _test_sub_tosa_BI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .check_count({"torch.ops.aten.sub.Tensor": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) - - def _test_sub_ethosu_BI_pipeline( - self, - compile_spec: list[CompileSpec], - module: torch.nn.Module, - test_data: Tuple[torch.Tensor], - ): - tester = ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=compile_spec, - ) - .quantize() - .export() - .check_count({"torch.ops.aten.sub.Tensor": 1}) - .check(["torch.ops.quantized_decomposed"]) - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .serialize() - ) - if conftest.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs(qtol=1, inputs=test_data) - - @parameterized.expand(Sub.test_parameters) - def test_sub_tosa_MI(self, test_data: torch.Tensor): - test_data = (test_data,) - self._test_sub_tosa_MI_pipeline(self.Sub(), test_data) - - @parameterized.expand(Sub.test_parameters) - def test_sub_tosa_BI(self, test_data: torch.Tensor): - test_data = (test_data,) - self._test_sub_tosa_BI_pipeline(self.Sub(), test_data) - - @parameterized.expand(Sub.test_parameters) - @pytest.mark.corstone_fvp - def test_sub_u55_BI(self, test_data: torch.Tensor): - test_data = (test_data,) - self._test_sub_ethosu_BI_pipeline( - common.get_u55_compile_spec(), self.Sub(), test_data - ) - - @parameterized.expand(Sub.test_parameters) - @pytest.mark.corstone_fvp - def test_sub_u85_BI(self, test_data: torch.Tensor): - test_data = (test_data,) - self._test_sub_ethosu_BI_pipeline( - common.get_u85_compile_spec(), self.Sub(), test_data - ) - - @parameterized.expand(Sub2.test_parameters) - def test_sub2_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_sub_tosa_MI_pipeline(self.Sub2(), test_data) - - @parameterized.expand(Sub2.test_parameters) - def test_sub2_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_sub_tosa_BI_pipeline(self.Sub2(), test_data) - - @parameterized.expand(Sub2.test_parameters) - @pytest.mark.corstone_fvp - def test_sub2_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_sub_ethosu_BI_pipeline( - common.get_u55_compile_spec(), self.Sub2(), test_data - ) - - @parameterized.expand(Sub2.test_parameters) - @pytest.mark.corstone_fvp - def test_sub2_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): - test_data = (operand1, operand2) - self._test_sub_ethosu_BI_pipeline( - common.get_u85_compile_spec(), self.Sub2(), test_data - ) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.sub.Tensor" +exir_op = "executorch_exir_dialects_edge__ops_aten_sub_Tensor" + +# Single-input subtraction (x - x) +sub_test_data = { + "ones_1D_5": (torch.ones(5),), + "ones_1D_50": (torch.ones(50),), + "rand_1D_10": (torch.rand(10),), + "rand_2D_5x5": (torch.rand(5, 5),), + "rand_3D_5x5x5": (torch.rand(5, 5, 5),), + "rand_4D_2x3x4x5": (torch.rand(2, 3, 4, 5),), + "zeros": (torch.zeros(10),), +} + +fvp_sub_xfails = {"rand_4D_2x3x4x5": "MLETORCH-517 : Multiple batches not supported"} + +# Two-input subtraction (x - y) +sub2_test_data = { + "rand_2D_4x4": (torch.rand(4, 4), torch.rand(4, 4)), + "rand_3D_4x4x4": (torch.rand(4, 2, 2), torch.rand(4, 2, 2)), + "rand_4D_2x2x4x4": (torch.rand(2, 2, 4, 4), torch.rand(2, 2, 4, 4)), + "zeros": (torch.rand(4, 4), torch.zeros(4, 4)), +} +fvp_sub2_xfails = {"rand_4D_2x2x4x4": "MLETORCH-517 : Multiple batches not supported"} + + +class Sub(torch.nn.Module): + def forward(self, x: torch.Tensor): + return x - x + + +class Sub2(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x - y + + +input_t1 = Tuple[torch.Tensor] # Input x +input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y + + +@common.parametrize("test_data", sub_test_data) +def test_sub_tosa_MI(test_data): + """Test Subtraction (TOSA MI)""" + pipeline = TosaPipelineMI[input_t1]( + Sub(), + test_data, + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", sub2_test_data) +def test_sub_2_tosa_MI(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction (TOSA MI)""" + pipeline = TosaPipelineMI[input_t2]( + Sub2(), + test_data, + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", sub_test_data) +def test_sub_tosa_BI(test_data): + """Test Subtraction (TOSA BI)""" + pipeline = TosaPipelineBI[input_t1]( + Sub(), + test_data, + aten_op, + exir_op, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_data", sub2_test_data) +def test_sub_2_tosa_BI(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction (TOSA BI)""" + pipeline = TosaPipelineBI[input_t2]( + Sub2(), + test_data, + aten_op, + exir_op, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_data", sub_test_data) +def test_sub_u55_BI(test_data): + """Test Subtraction on Ethos-U55""" + pipeline = EthosU55PipelineBI[input_t1]( + Sub(), + test_data, + aten_op, + exir_op, + run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_data", sub2_test_data) +def test_sub_2_u55_BI(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction on Ethos-U55""" + pipeline = EthosU55PipelineBI[input_t2]( + Sub2(), + test_data, + aten_op, + exir_op, + run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_data", sub_test_data) +def test_sub_u85_BI(test_data): + """Test Subtraction on Ethos-U85 (Quantized Mode)""" + pipeline = EthosU85PipelineBI[input_t1]( + Sub(), + test_data, + aten_op, + exir_op, + run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_data", sub2_test_data) +def test_sub_2_u85_BI(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction on Ethos-U85""" + pipeline = EthosU85PipelineBI[input_t2]( + Sub2(), + test_data, + aten_op, + exir_op, + run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_data", sub_test_data, fvp_sub_xfails) +@common.SkipIfNoCorstone300 +def test_sub_u55_BI_on_fvp(test_data): + """Test Subtraction on Ethos-U55 (FVP Mode)""" + pipeline = EthosU55PipelineBI[input_t1]( + Sub(), + test_data, + aten_op, + exir_op, + run_on_fvp=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_data", sub2_test_data, fvp_sub2_xfails) +@common.SkipIfNoCorstone300 +def test_sub_2_u55_BI_on_fvp(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction on Ethos-U55 (FVP Mode)""" + pipeline = EthosU55PipelineBI[input_t2]( + Sub2(), + test_data, + aten_op, + exir_op, + run_on_fvp=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_data", sub_test_data, fvp_sub_xfails) +@common.SkipIfNoCorstone320 +def test_sub_u85_BI_on_fvp(test_data): + """Test Subtraction on Ethos-U85 (FVP Mode)""" + pipeline = EthosU85PipelineBI[input_t1]( + Sub(), + test_data, + aten_op, + exir_op, + run_on_fvp=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + +@common.parametrize("test_data", sub2_test_data, fvp_sub2_xfails) +@common.SkipIfNoCorstone320 +def test_sub_2_u85_BI_on_fvp(test_data: Tuple[torch.Tensor, torch.Tensor]): + """Test Two-Operand Subtraction on Ethos-U85 (FVP Mode)""" + pipeline = EthosU85PipelineBI[input_t2]( + Sub2(), + test_data, + aten_op, + exir_op, + run_on_fvp=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 62d0b633224..99166cd1b5e 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -7,8 +7,16 @@ from typing import Callable, Dict, Generic, List, Optional, Type, TypeVar import torch + +from executorch.backends.arm.quantizer.arm_quantizer import ( + EthosUQuantizer, + get_symmetric_quantization_config, + TOSAQuantizer, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses + +from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.pass_base import ExportPass from torch._export.pass_base import PassType @@ -263,12 +271,21 @@ def __init__( aten_op: str | List[str], exir_op: Optional[str | List[str]] = None, tosa_version: str = "TOSA-0.80+BI", + symmetric_io_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, ): compile_spec = common.get_tosa_compile_spec( tosa_version, custom_path=custom_path ) + quant_stage = ( + Quantize( + TOSAQuantizer(compile_spec).set_io(get_symmetric_quantization_config()), + get_symmetric_quantization_config(), + ) + if symmetric_io_quantization + else None + ) super().__init__( module, test_data, @@ -277,7 +294,8 @@ def __init__( exir_op, use_to_edge_transform_and_lower, ) - self.add_stage(self.tester.quantize, pos=0) + self.add_stage(self.tester.quantize, quant_stage, pos=0) + self.add_stage_after( "quantize", self.tester.check, @@ -385,10 +403,21 @@ def __init__( aten_ops: str | List[str], exir_ops: Optional[str | List[str]] = None, run_on_fvp: bool = False, + symmetric_io_quantization: bool = False, use_to_edge_transform_and_lower: bool = False, custom_path: str = None, ): compile_spec = common.get_u55_compile_spec(custom_path=custom_path) + quant_stage = ( + Quantize( + EthosUQuantizer(compile_spec).set_io( + get_symmetric_quantization_config() + ), + get_symmetric_quantization_config(), + ) + if symmetric_io_quantization + else None + ) super().__init__( module, test_data, @@ -397,7 +426,9 @@ def __init__( exir_ops, use_to_edge_transform_and_lower, ) - self.add_stage(self.tester.quantize, pos=0) + + self.add_stage(self.tester.quantize, quant_stage, pos=0) + self.add_stage_after( "quantize", self.tester.check, @@ -455,10 +486,21 @@ def __init__( aten_ops: str | List[str], exir_ops: str | List[str] = None, run_on_fvp: bool = False, + symmetric_io_quantization: bool = False, use_to_edge_transform_and_lower: bool = False, custom_path: str = None, ): compile_spec = common.get_u85_compile_spec(custom_path=custom_path) + quant_stage = ( + Quantize( + EthosUQuantizer(compile_spec).set_io( + get_symmetric_quantization_config() + ), + get_symmetric_quantization_config(), + ) + if symmetric_io_quantization + else None + ) super().__init__( module, test_data, @@ -467,7 +509,9 @@ def __init__( exir_ops, use_to_edge_transform_and_lower, ) - self.add_stage(self.tester.quantize, pos=0) + + self.add_stage(self.tester.quantize, quant_stage, pos=0) + self.add_stage_after( "quantize", self.tester.check,