diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 6bb1ce7dce1..a5f66829da5 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -48,8 +48,6 @@ class TableOps: exir_ops.edge.aten.reciprocal.default: torch.reciprocal, exir_ops.edge.aten.rsqrt.default: torch.rsqrt, exir_ops.edge.aten.sigmoid.default: torch.sigmoid, - exir_ops.edge.aten.cos.default: torch.cos, - exir_ops.edge.aten.sin.default: torch.sin, exir_ops.edge.aten.tanh.default: torch.tanh, exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid, exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish, diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index 42ab9bca3cd..bd54c3e1f85 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -12,7 +12,6 @@ pool_2d_support, reduce_sum_support, right_shift_support, - sin_cos_support, slice_copy_support, to_copy_support, tosa_supported_operators, diff --git a/backends/arm/operator_support/sin_cos_support.py b/backends/arm/operator_support/sin_cos_support.py deleted file mode 100644 index 9dd63e8258d..00000000000 --- a/backends/arm/operator_support/sin_cos_support.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe - - -import torch.fx as fx -from executorch.backends.arm.operator_support.tosa_supported_operators import ( - register_tosa_support_check, - SupportedTOSAOperatorCheck, -) -from executorch.backends.arm.tosa_specification import TosaSpecification -from executorch.exir.dialects._ops import ops as exir_ops - - -@register_tosa_support_check -class SinCosSupported(SupportedTOSAOperatorCheck): - targets = [ - exir_ops.edge.aten.cos.default, - exir_ops.edge.aten.sin.default, - ] - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+BI"), - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): - return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index d8d1711e725..5de90bda252 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -23,11 +23,7 @@ EthosU55NotSupported, EthosU55TransposeCheck, ) -from executorch.backends.arm.tosa_specification import ( - Tosa_0_80, - Tosa_1_00, - TosaSpecification, -) +from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops @@ -128,9 +124,7 @@ def tosa_support_factory( if not tosa_spec.support_float(): negative_checks.append(NeedsDecompositionCheck(reporter)) negative_checks.append(CheckProperQuantization(reporter)) - if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or ( - isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions - ): + if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) negative_checks.append(EthosU55DtypeSupport(reporter)) negative_checks.append(EthosU55TransposeCheck(reporter)) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index e496fe74d54..3ee243779e6 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -18,7 +18,6 @@ op_clamp, op_constant_pad_nd, op_conv2d, - op_cos, op_eq, op_erf, op_exp, @@ -39,7 +38,6 @@ op_rshift_tensor, op_rsqrt, op_sigmoid, - op_sin, op_slice, op_sub, op_sum, diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index efb5b0b72b0..b65ebb2ac5d 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -4,8 +4,9 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, cast, List +from typing import cast, List +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( # type: ignore NodeVisitor, register_node_visitor, @@ -15,59 +16,17 @@ from torch.fx import Node -@register_node_visitor -class AnyVisitor_0_80(NodeVisitor): - target = "aten.any.dim" - - tosa_specs = NodeVisitor.tosa_specs_0_80 - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if not (inputs[0].dtype == output.dtype): - raise ValueError( - "All inputs and outputs need same dtype." - f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}." - ) - if not (inputs[0].dtype == ts.DType.BOOL): - raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}") - - input_shape = list(inputs[0].shape) - dim = cast(int, inputs[1].number) % len( - input_shape - ) # process the negative index - keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) - if not keep_dim: - raise ValueError("This case should be handled by ConvertAnyDimDimsPass") - - attr = ts.TosaSerializerAttribute() - attr.AxisAttribute(inputs[0].dim_order.index(dim)) - - tosa_graph.addOperator( - ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr - ) - - @register_node_visitor class AnyVisitor(NodeVisitor): target = "aten.any.dim" - tosa_specs = NodeVisitor.tosa_specs_1_00 - def define_node( self, node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts if not (inputs[0].dtype == output.dtype): raise ValueError( @@ -86,7 +45,7 @@ def define_node( raise ValueError("This case should be handled by ConvertAnyDimDimsPass") attr = ts.TosaSerializerAttribute() - attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim)) + attr.AxisAttribute(inputs[0].dim_order.index(dim)) tosa_graph.addOperator( ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 73a6713633a..bdd3425fda5 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -4,10 +4,12 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -34,16 +36,14 @@ def __init__(self, *args): def _build_generic_avgpool2d( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, input_zp: int, output_zp: int, - accumulator_type: Any, + accumulator_type: ts.DType, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - input_tensor = inputs[0] kernel_size_list = inputs[1].special stride_size_list = inputs[2].special @@ -79,12 +79,10 @@ def _build_generic_avgpool2d( def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - input_tensor = inputs[0] assert input_tensor.dtype == ts.DType.INT8 @@ -112,135 +110,10 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI): def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - assert ( - inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 - ), "Only FP32 and INT8 supported" - - if inputs[0].dtype == ts.DType.INT8: - super().define_node(node, tosa_graph, inputs, output) - - if inputs[0].dtype == ts.DType.FP32: - accumulator_type = ts.DType.FP32 - # Initilize zero point to zero. - input_zp = 0 - output_zp = 0 - - self._build_generic_avgpool2d( - node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type - ) - - -@register_node_visitor -class AvgPool2dVisitor(NodeVisitor): - target = "aten.avg_pool2d.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def _build_generic_avgpool2d( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - input_zp: int, - output_zp: int, - accumulator_type: Any, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - input_tensor = inputs[0] - kernel_size_list = inputs[1].special - stride_size_list = inputs[2].special - - try: - pad_size_list = inputs[3].special - pad_size_list = [ - pad_size_list[0], - pad_size_list[0], - pad_size_list[1], - pad_size_list[1], - ] - except IndexError: - pad_size_list = [0, 0, 0, 0] - - attr = ts.TosaSerializerAttribute() - attr.AvgPool2dAttribute( - kernel=kernel_size_list, - stride=stride_size_list, - pad=pad_size_list, - acc_type=accumulator_type, - ) - input_zp_tensor = tosa_graph.addConst( - shape=[1], dtype=output.dtype, vals=[input_zp] - ) - output_zp_tensor = tosa_graph.addConst( - shape=[1], dtype=output.dtype, vals=[output_zp] - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().AVG_POOL2D, - [input_tensor.name, input_zp_tensor.name, output_zp_tensor.name], - [output.name], - attr, - ) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - - input_tensor = inputs[0] - assert input_tensor.dtype == ts.DType.INT8 - - accumulator_type = ts.DType.INT32 - - input_qargs = get_input_qparams(node) - input_zp = input_qargs[0].zp - - output_qargs = get_output_qparams(node) - output_zp = output_qargs[0].zp - - self._build_generic_avgpool2d( - node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type - ) - - -@register_node_visitor -class AvgPool2dVisitor_FP(AvgPool2dVisitor): - target = "aten.avg_pool2d.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore - assert ( inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 ), "Only FP32 and INT8 supported" diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index bb77ba77940..6b1710301b1 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -5,8 +5,9 @@ # pyre-unsafe -from typing import Any, List +from typing import List +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -15,58 +16,20 @@ from torch.fx import Node -@register_node_visitor -class CatVisitor_0_80(NodeVisitor): - target = "aten.cat.default" - - tosa_specs = NodeVisitor.tosa_specs_0_80 - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - tensors = inputs[0].special - dim = 0 if len(inputs) < 2 else inputs[1].number - rank = len(output.shape) - dim = (dim + rank) % rank - dim = output.dim_order.index(dim) - - attr = ts.TosaSerializerAttribute() - attr.AxisAttribute(dim) - - tosa_graph.addOperator( - ts.TosaOp.Op().CONCAT, - [tensor.name for tensor in tensors], - [output.name], - attr, - ) - - @register_node_visitor class CatVisitor(NodeVisitor): target = "aten.cat.default" - tosa_specs = NodeVisitor.tosa_specs_1_00 - def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts tensors = inputs[0].special dim = 0 if len(inputs) < 2 else inputs[1].number @@ -75,7 +38,7 @@ def define_node( dim = output.dim_order.index(dim) attr = ts.TosaSerializerAttribute() - attr.ConcatAttribute(dim) + attr.AxisAttribute(dim) tosa_graph.addOperator( ts.TosaOp.Op().CONCAT, diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 75cdc0b0fc4..b2c31df96ab 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -5,10 +5,12 @@ # pyre-unsafe -from typing import Any, List +from typing import List import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -17,27 +19,20 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification @register_node_visitor -class ConstantPadNDVisitor_0_80(NodeVisitor): +class ConstantPadNDVisitor(NodeVisitor): target = "aten.constant_pad_nd.default" - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+BI"), - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) @@ -79,72 +74,3 @@ def define_node( tosa_graph.addOperator( ts.TosaOp.Op().PAD, [inputs[0].name], [output.name], attr ) - - -@register_node_visitor -class ConstantPadNDVisitor(NodeVisitor): - - target = "aten.constant_pad_nd.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - qargs = input_qparams[0] - pad_const_val = qargs.quantize_value(inputs[2].number).item() - pad_const_dtype = ts.DType.INT8 - else: - pad_const_val = inputs[2].number - pad_const_dtype = inputs[0].dtype - - rank = len(output.shape) - # Each dim needs 2 padding values. For example, to pad the last dimension, the pad has the form - # (padding_left, padding_right); to pad the last two dimensions, the pad has the form - # (padding_left, padding_right, padding_top, padding_bottom), and so on. For PyTorch NCHW format, the padding - # values are in the reverse order. So, firstly we need to reverse the input padding parameters. - input_pad = sum( - [ - [inputs[1].special[i], inputs[1].special[i + 1]] - for i in range(0, len(inputs[1].special), 2) - ][::-1], - [], - ) - # Then, add dummy zeros to make sure that both input_pad and output_pad has the same size. - input_pad = [0] * (rank * 2 - len(inputs[1].special)) + input_pad - # For PyTorch NCHW format, dim order is [0,...,rank-1] - input_dim_order = list(range(rank)) - output_pad = [0] * rank * 2 - - # Map input padding parameters into output padding parameters. TOSA is NHWC format. - for input_dim_idx, input_dim in enumerate(input_dim_order): - output_dim_idx = output.dim_order.index(input_dim) - output_pad[output_dim_idx * 2 : (output_dim_idx + 1) * 2] = input_pad[ - input_dim_idx * 2 : (input_dim_idx + 1) * 2 - ] - - padding = tosa_graph.addConst( - shape=[len(output_pad)], dtype=ts.DType.SHAPE, vals=output_pad - ) - - pad_const = tosa_graph.addConst( - shape=[1], dtype=pad_const_dtype, vals=[pad_const_val] - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().PAD, - [inputs[0].name, padding.name, pad_const.name], - [output.name], - ) diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py deleted file mode 100644 index 1fee25511ce..00000000000 --- a/backends/arm/operators/op_cos.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts # type: ignore -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification -from torch.fx import Node - - -@register_node_visitor -class CosVisitor(NodeVisitor): - target = "aten.cos.default" - - # INT case should be handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) - if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: - raise ValueError( - f"Input and output for {self.target} need to be FP32, got input_dtype: " - f"{inputs[0].dtype} and output_dtype: {output.dtype}" - ) - - tosa_graph.addOperator(ts.TosaOp.Op().COS, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index e174069ee77..01243716129 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -3,9 +3,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List import torch.fx +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -27,42 +29,10 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if not (inputs[0].dtype == output.dtype): - raise ValueError( - "All inputs and output need same dtype." - f"Got {inputs[0].dtype=}, {output.dtype=}" - ) - if not (inputs[0].dtype == ts.DType.FP32): - raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}") - # MI lowering - tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name]) - - -@register_node_visitor -class ERFVisitor(NodeVisitor): - target = "aten.erf.default" - - # INT case handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." @@ -70,6 +40,5 @@ def define_node( ) if not (inputs[0].dtype == ts.DType.FP32): raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}") - # MI lowering - tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name]) + tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 60cc727d149..ca067b3b8be 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -4,8 +4,9 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -28,43 +29,10 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) - if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: - raise ValueError( - f"Input and output for {self.target} need to be FP32, got input dtype: " - f"{inputs[0].dtype} and output dtype: {output.dtype}" - ) - - tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name]) - - -@register_node_visitor -class ExpVisitor(NodeVisitor): - target = "aten.exp.default" - - # BI case should be handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts if len(node.all_input_nodes) != 1: raise ValueError( diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index b08bbcec003..9942cbf4702 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -4,8 +4,9 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -16,7 +17,7 @@ @register_node_visitor -class LogVisitor_0_80_MI(NodeVisitor): +class LogVisitor(NodeVisitor): target = "aten.log.default" # BI case should be handled by op_table @@ -28,44 +29,10 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) - if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: - raise ValueError( - f"Input and output for {self.target} need to be FP32, got input_dtype: " - f"{inputs[0].dtype} and output_dtype: {output.dtype}" - ) - - tosa_graph.addOperator(ts.TosaOp.Op().LOG, [inputs[0].name], [output.name]) - - -@register_node_visitor -class LogVisitor(NodeVisitor): - target = "aten.log.default" - - # INT case should be handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 928262aefc5..fcf2636977d 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -4,10 +4,12 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -17,29 +19,22 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification @register_node_visitor -class MaxPool2dVisitor_0_80(NodeVisitor): +class MaxPool2dVisitor(NodeVisitor): target = "aten.max_pool2d.default" - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+BI"), - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore input_tensor = inputs[0] kernel_size = inputs[1].special @@ -85,53 +80,3 @@ def define_node( [output.name], attr, ) - - -@register_node_visitor -class MaxPool2dVisitor(NodeVisitor): - target = "aten.max_pool2d.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - input_tensor = inputs[0] - kernel_size = inputs[1].special - stride = inputs[2].special - - try: - pad_size_list = inputs[3].special - pad_size_list = [ - pad_size_list[0], - pad_size_list[0], - pad_size_list[1], - pad_size_list[1], - ] - except IndexError: - pad_size_list = [0, 0, 0, 0] - - attr = ts.TosaSerializerAttribute() - attr.MaxPool2dAttribute( - kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1 - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().MAX_POOL2D, - [input_tensor.name], - [output.name], - attr, - ) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index f3ea8b00961..c92a008a281 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -5,10 +5,11 @@ # pyre-unsafe -from typing import Any, List +from typing import List import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -87,61 +88,20 @@ def transform_permutation_vector(permutation_vector: list[int], dim_order: list[ return permutation_vector -@register_node_visitor -class PermuteVisitor_0_80(NodeVisitor): - target = "aten.permute_copy.default" - - tosa_specs = NodeVisitor.tosa_specs_0_80 - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - # The permutation vector describes a permutation P in default Pytorch dim_order. - # For rank 4, the default dim_order NCHW. - # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) - permutation_vector = inputs[1].special - - if output.dim_order != tuple(range(len(output.dim_order))): - # the permutation vector can't be used directly if we are not in NCHW dim_order. - # Transform to dim_order. - permutation_vector = transform_permutation_vector( - permutation_vector, output.dim_order - ) - - attr = ts.TosaSerializerAttribute() - attr.TransposeAttribute(permutation_vector) - tosa_graph.addOperator( - ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr - ) - - @register_node_visitor class PermuteVisitor(NodeVisitor): target = "aten.permute_copy.default" - tosa_specs = NodeVisitor.tosa_specs_1_00 - def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index 781fce3c79f..d3b92feff12 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -5,8 +5,9 @@ # pyre-unsafe -from typing import Any, List +from typing import List +import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -30,53 +31,10 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if not (inputs[0].dtype == inputs[1].dtype == output.dtype): - raise ValueError( - "All inputs and outputs need same dtype." - f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}" - ) - if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]: - raise ValueError( - f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}" - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().POW, - [ - inputs[0].name, - inputs[1].name, - ], - [output.name], - None, - ) - - -@register_node_visitor -class PowVisitor(NodeVisitor): - target = "aten.pow.Tensor_Tensor" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 7d1ee951993..c75fb99977e 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -4,10 +4,11 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -29,46 +30,10 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) - if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: - raise ValueError( - f"Input and output for {self.target} need to be FP32, got " - f"{inputs[0].dtype=} and {output.dtype=}" - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] - ) - - -@register_node_visitor -class ReciprocalVisitor(NodeVisitor): - target = "aten.reciprocal.default" - - # INT case should be handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 375dd76ba8d..125f5493a29 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -5,67 +5,34 @@ # pyre-unsafe -from typing import Any, List +from typing import List import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00 - - -@register_node_visitor -class RshiftVisitor_0_80(NodeVisitor): - target = "aten.bitwise_right_shift.Tensor" - - tosa_specs = NodeVisitor.tosa_specs_0_80 - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - attr = ts.TosaSerializerAttribute() - round = False - if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: - # U55 only supports INT32 and round == True - # TODO MLETORCH-525 Emulate round == False with different decomposition - round = True - attr.ArithmeticRightShiftAttribute(round=round) - - tosa_graph.addOperator( - ts.TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, - [inputs[0].name, inputs[1].name], - [output.name], - attr, - ) +from executorch.backends.arm.tosa_specification import Tosa_0_80 @register_node_visitor class RshiftVisitor(NodeVisitor): target = "aten.bitwise_right_shift.Tensor" - tosa_specs = NodeVisitor.tosa_specs_1_00 - def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts attr = ts.TosaSerializerAttribute() round = False - if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions: + if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: # U55 only supports INT32 and round == True # TODO MLETORCH-525 Emulate round == False with different decomposition round = True diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 784c4b4d257..e3937f8c44a 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -4,10 +4,11 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -29,44 +30,10 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) - if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: - raise ValueError( - f"Input and output for {self.target} need to be FP32, got " - f"{inputs[0].dtype=} and {output.dtype=}" - ) - - tosa_graph.addOperator(ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name]) - - -@register_node_visitor -class RsqrtVisitor(NodeVisitor): - target = "aten.rsqrt.default" - - # INT case should be handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index a43e9ae798f..9a002036fee 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -4,8 +4,9 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -28,43 +29,10 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) - if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: - raise ValueError( - f"Input and output for {self.target} need to be FP32, got input_dtype: " - f"{inputs[0].dtype} and output_dtype: {output.dtype}" - ) - - tosa_graph.addOperator(ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name]) - - -@register_node_visitor -class SigmoidVisitor(NodeVisitor): - target = "aten.sigmoid.default" - - # INT case should be handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts if len(node.all_input_nodes) != 1: raise ValueError( diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py deleted file mode 100644 index ee444c38f37..00000000000 --- a/backends/arm/operators/op_sin.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts # type: ignore -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import TosaSpecification -from torch.fx import Node - - -@register_node_visitor -class SinVisitor(NodeVisitor): - target = "aten.sin.default" - - # INT case should be handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) - if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: - raise ValueError( - f"Input and output for {self.target} need to be FP32, got input_dtype: " - f"{inputs[0].dtype} and output_dtype: {output.dtype}" - ) - - tosa_graph.addOperator(ts.TosaOp.Op().SIN, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 01af36c4d37..51cf1ee786b 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -4,8 +4,9 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -16,7 +17,7 @@ @register_node_visitor -class TanhVisitor_0_80_MI(NodeVisitor): +class TanhVisitor_080_MI(NodeVisitor): target = "aten.tanh.default" # BI case should be handled by op_table @@ -28,44 +29,10 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if len(node.all_input_nodes) != 1: - raise ValueError( - f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" - ) - if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: - raise ValueError( - f"Input and output for {self.target} need to be FP32, got input_dtype: " - f"{inputs[0].dtype} and output_dtype: {output.dtype}" - ) - - tosa_graph.addOperator(ts.TosaOp.Op().TANH, [inputs[0].name], [output.name]) - - -@register_node_visitor -class TanhVisitor(NodeVisitor): - target = "aten.tanh.default" - - # INT case should be handled by op_table - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py index 63694b715f0..23d24b78339 100644 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ b/backends/arm/operators/op_upsample_nearest2d.py @@ -4,10 +4,11 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List import torch +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -19,23 +20,19 @@ @register_node_visitor -class UpsampleNearest2dVisitor_0_80(NodeVisitor): +class UpsampleNearest2dVisitor(NodeVisitor): target = "aten.upsample_nearest2d.vec" - tosa_specs = NodeVisitor.tosa_specs_0_80 - def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - assert ( inputs[0].shape is not None and output.shape is not None ), "Only static shapes are supported" @@ -70,74 +67,3 @@ def in_int16_range(x): tosa_graph.addOperator( ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr ) - - -@register_node_visitor -class UpsampleNearest2dVisitor(NodeVisitor): - target = "aten.upsample_nearest2d.vec" - - tosa_specs = NodeVisitor.tosa_specs_1_00 - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - - assert ( - inputs[0].shape is not None and output.shape is not None - ), "Only static shapes are supported" - - # tosa_shape output is NHWC, take HW - input_size_yx = torch.tensor( - tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3] - ) - # Ignore scale and size parameters, directly use the output size as - # we only support static shapes currently - output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3]) - - scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( - input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True - ) - - def in_int16_range(x): - return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) - - assert in_int16_range(scale_n_yx) - assert in_int16_range(scale_d_yx) - assert in_int16_range(border_yx) - - scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]] - scales_tensor = tosa_graph.addConst( - [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" - ) - offset = offset_yx.tolist() - offset_tensor = tosa_graph.addConst( - [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" - ) - border = border_yx.tolist() - border_tensor = tosa_graph.addConst( - [len(border)], ts.DType.SHAPE, border, node.name + "_border" - ) - attr = ts.TosaSerializerAttribute() - attr.ResizeAttribute( - mode=ResizeMode.NEAREST, - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().RESIZE, - [ - inputs[0].name, - scales_tensor.name, - offset_tensor.name, - border_tensor.name, - ], - [output.name], - attr, - ) diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index b58fda1c399..ba2469e74e1 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -3,7 +3,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List, Sequence +from typing import List, Sequence + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore +import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -14,148 +17,54 @@ from torch.fx import Node -@register_node_visitor -class WhereVisitor_0_80_BI(NodeVisitor): - target = "aten.where.self" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+BI"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def _add_node_to_tosa_graph( - self, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - supported_dtypes: Sequence, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - if len(inputs) != 3: - raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") - - if inputs[0].dtype is not ts.DType.BOOL: - raise ValueError("Input 0 needs to have dtype BOOL") - if inputs[1].dtype != inputs[2].dtype: +def _add_node_to_tosa_graph( + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + supported_dtypes: Sequence, +) -> None: + if len(inputs) != 3: + raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") + + if inputs[0].dtype is not ts.DType.BOOL: + raise ValueError("Input 0 needs to have dtype BOOL") + if inputs[1].dtype != inputs[2].dtype: + raise ValueError( + "Non-condition tensors must have same data type, got " + f"{inputs[1].dtype} and {inputs[2].dtype}" + ) + for input_ in inputs[1:]: + if input_.dtype not in supported_dtypes: raise ValueError( - "Non-condition tensors must have same data type, got " - f"{inputs[1].dtype} and {inputs[2].dtype}" + f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}" ) - for input_ in inputs[1:]: - if input_.dtype not in supported_dtypes: - raise ValueError( - f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}" - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().SELECT, - [inputs[0].name, inputs[1].name, inputs[2].name], - [output.name], - None, - ) - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - bi_supported_dtypes = [ - ts.DType.INT8, - ts.DType.INT16, - ts.DType.INT32, - ts.DType.BOOL, - ] - self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes) + tosa_graph.addOperator( + TosaOp.Op().SELECT, + [inputs[0].name, inputs[1].name, inputs[2].name], + [output.name], + None, + ) @register_node_visitor -class WhereVisitor_0_80_MI(WhereVisitor_0_80_BI): - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+MI"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - - mi_supported_dtypes = [ - ts.DType.FP16, - ts.DType.FP32, - ts.DType.INT8, - ts.DType.INT16, - ts.DType.INT32, - ts.DType.BOOL, - ] - self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes) - - -@register_node_visitor -class WhereVisitor_INT(NodeVisitor): +class WhereVisitor_080_BI(NodeVisitor): target = "aten.where.self" tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-0.80+BI"), ] def __init__(self, *args): super().__init__(*args) - def _add_node_to_tosa_graph( - self, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - supported_dtypes: Sequence, - ) -> None: - import serializer.tosa_serializer as ts - - if len(inputs) != 3: - raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") - - if inputs[0].dtype is not ts.DType.BOOL: - raise ValueError("Input 0 needs to have dtype BOOL") - if inputs[1].dtype != inputs[2].dtype: - raise ValueError( - "Non-condition tensors must have same data type, got " - f"{inputs[1].dtype} and {inputs[2].dtype}" - ) - for input_ in inputs[1:]: - if input_.dtype not in supported_dtypes: - raise ValueError( - f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}" - ) - - tosa_graph.addOperator( - ts.TosaOp.Op().SELECT, - [inputs[0].name, inputs[1].name, inputs[2].name], - [output.name], - None, - ) - def define_node( self, node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts bi_supported_dtypes = [ ts.DType.INT8, @@ -163,14 +72,14 @@ def define_node( ts.DType.INT32, ts.DType.BOOL, ] - self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes) + _add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes) @register_node_visitor -class WhereVisitor_FP(WhereVisitor_INT): +class WhereVisitor_080_MI(WhereVisitor_080_BI): tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+FP"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), ] def __init__(self, *args): @@ -179,12 +88,10 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts - mi_supported_dtypes = [ ts.DType.FP16, ts.DType.FP32, @@ -193,4 +100,4 @@ def define_node( ts.DType.INT32, ts.DType.BOOL, ] - self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes) + _add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes) diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index 425007bab3c..a17da41f767 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -5,11 +5,13 @@ # pyre-unsafe -from typing import Any, List +from typing import List import torch import torch.fx +import tosa_tools.v0_80.serializer.tosa_serializer as ts + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -17,50 +19,19 @@ from executorch.backends.arm.tosa_mapping import TosaArg -def binary_operator_factory_0_80(bw_target: str, tosa_op): - """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op.""" - - class BinaryOperator_0_80(NodeVisitor): - target = bw_target - tosa_specs = NodeVisitor.tosa_specs_0_80 - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401 - - if not (inputs[0].dtype == inputs[1].dtype == output.dtype): - raise ValueError( - "All inputs and outputs need same dtype." - f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}." - ) - - tosa_graph.addOperator( - tosa_op, [inputs[0].name, inputs[1].name], [output.name] - ) - - register_node_visitor(BinaryOperator_0_80) - - def binary_operator_factory(bw_target: str, tosa_op): """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op.""" class BinaryOperator(NodeVisitor): target = bw_target - tosa_specs = NodeVisitor.tosa_specs_1_00 def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore # noqa: F401 if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( @@ -75,20 +46,6 @@ def define_node( register_node_visitor(BinaryOperator) -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - -binary_operator_factory_0_80("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND) -binary_operator_factory_0_80("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR) -binary_operator_factory_0_80("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR) -binary_operator_factory_0_80("aten.logical_and.default", ts.TosaOp.Op().LOGICAL_AND) -binary_operator_factory_0_80("aten.logical_xor.default", ts.TosaOp.Op().LOGICAL_XOR) -binary_operator_factory_0_80("aten.logical_or.default", ts.TosaOp.Op().LOGICAL_OR) -binary_operator_factory_0_80( - "aten.bitwise_left_shift.Tensor", ts.TosaOp.Op().LOGICAL_LEFT_SHIFT -) - -import serializer.tosa_serializer as ts # type: ignore - binary_operator_factory("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND) binary_operator_factory("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR) binary_operator_factory("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR) diff --git a/backends/arm/operators/ops_unary.py b/backends/arm/operators/ops_unary.py index 3bb2be16585..3f713e086e6 100644 --- a/backends/arm/operators/ops_unary.py +++ b/backends/arm/operators/ops_unary.py @@ -4,10 +4,11 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import Any, List +from typing import List import torch.fx +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -16,44 +17,6 @@ from executorch.backends.arm.tosa_mapping import TosaArg -def unary_operator_factory_0_80(unary_target: str, tosa_op): - "Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op." - - # Some TOSA unary operators only support float - fp_only_ops = ["aten.floor.default"] - - class UnaryOperator_0_80(NodeVisitor): - target = unary_target - tosa_specs = NodeVisitor.tosa_specs_0_80 - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401 - - if not (inputs[0].dtype == output.dtype): - raise ValueError( - "All inputs and output need same dtype." - f"Got {inputs[0].dtype=}, {output.dtype=}" - ) - - if self.target in fp_only_ops and not (inputs[0].dtype == ts.DType.FP32): - raise ValueError( - "All inputs need to be FP32." f"Got {inputs[0].dtype=}" - ) - - tosa_graph.addOperator(tosa_op, [inputs[0].name], [output.name]) - - register_node_visitor(UnaryOperator_0_80) - - def unary_operator_factory(unary_target: str, tosa_op): "Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op." @@ -62,7 +25,6 @@ def unary_operator_factory(unary_target: str, tosa_op): class UnaryOperator(NodeVisitor): target = unary_target - tosa_specs = NodeVisitor.tosa_specs_1_00 def __init__(self, *args): super().__init__(*args) @@ -70,11 +32,10 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: Any, + tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, ) -> None: - import serializer.tosa_serializer as ts # type: ignore # noqa: F401 if not (inputs[0].dtype == output.dtype): raise ValueError( @@ -92,14 +53,6 @@ def define_node( register_node_visitor(UnaryOperator) -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - -unary_operator_factory_0_80("aten.ceil.default", ts.TosaOp.Op().CEIL) -unary_operator_factory_0_80("aten.floor.default", ts.TosaOp.Op().FLOOR) -unary_operator_factory_0_80("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT) - -import serializer.tosa_serializer as ts # type: ignore - unary_operator_factory("aten.ceil.default", ts.TosaOp.Op().CEIL) unary_operator_factory("aten.floor.default", ts.TosaOp.Op().FLOOR) unary_operator_factory("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5ac747177be..d24ac0c887e 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -177,8 +177,6 @@ def _match_pattern( torch.ops.aten.reciprocal.default, torch.ops.aten.rsqrt.default, torch.ops.aten.sigmoid.default, - torch.ops.aten.cos.default, - torch.ops.aten.sin.default, torch.ops.aten.tanh.default, torch.ops.aten.sum.dim_IntList, torch.ops.aten.hardsigmoid.default, diff --git a/backends/arm/test/misc/test_multiple_delegates.py b/backends/arm/test/misc/test_multiple_delegates.py index ab768d273c6..9103e8ca899 100644 --- a/backends/arm/test/misc/test_multiple_delegates.py +++ b/backends/arm/test/misc/test_multiple_delegates.py @@ -20,7 +20,7 @@ def get_inputs(self): def forward(self, x: torch.Tensor, y: torch.Tensor): z = x + y - s = torch.tan(z) + s = torch.sin(z) return s * z def test_tosa_MI(self): diff --git a/backends/arm/test/ops/test_constant_pad_nd.py b/backends/arm/test/ops/test_constant_pad_nd.py index 9a19f6fbf5f..3dfd28640eb 100644 --- a/backends/arm/test/ops/test_constant_pad_nd.py +++ b/backends/arm/test/ops/test_constant_pad_nd.py @@ -2,74 +2,143 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + # # Test the pad_constant_nd op which pads the input tensor at specific dimension(s). # +import unittest from typing import Tuple import torch +import torch.nn as nn import torch.nn.functional as F from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.test_pipeline import ( - TosaPipelineBI, - TosaPipelineMI, -) +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized + +test_data_suite = [ + ("4dim_last1dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1), + ("4dim_last2dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2), + ("4dim_last3dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3), + ("4dim_last4dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4), + ("3dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1), + ("3dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2), + ("3dim_last3dim", torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3), + ("2dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0), 1), + ("2dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1), 2), +] + + +class TestConstantPadND(unittest.TestCase): + """Tests pad.""" + + class ConstantPadND(torch.nn.Module): + def __init__(self, pad: Tuple, value: float | None = None): + super().__init__() + self.dim = len(pad) // 2 + self.value = value + in_channels = 1 + # Only apply conv2d when the input dim = 4. + if self.dim == 4: + in_channels += pad[-3] + pad[-4] + + self.conv2d = nn.Conv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + bias=True, + stride=(2, 2), + padding=0, + ) -aten_op = "torch.ops.aten.pad.default" -exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default" -input_t1 = Tuple[torch.Tensor] # Input x -test_data_suite = { - "4dim_last1dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1), - "4dim_last2dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2), - "4dim_last3dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3), - "4dim_last4dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4), - "3dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1), - "3dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2), - "3dim_last3dim": (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3), - "2dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0), 1), - "2dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1), 2), -} -"""Tests pad.""" + in_channels = 3 + in_channels += pad[-3] + pad[-4] + self.conv2d_1 = nn.Conv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + bias=True, + padding="same", + ) + nonzero_idx = len(pad) + for i in range(0, len(pad), 2): + if pad[i] + pad[i + 1] == 0: + nonzero_idx = i + break + self.pad = pad[:nonzero_idx] + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() -class ConstantPadND(torch.nn.Module): - def __init__(self, pad: Tuple, value: float | None = None): - super().__init__() - self.value = value - nonzero_idx = len(pad) - for i in range(0, len(pad), 2): - if pad[i] + pad[i + 1] == 0: - nonzero_idx = i - break - self.pad = pad[:nonzero_idx] + def forward(self, x: torch.Tensor): + x = F.pad(x, pad=self.pad, mode="constant", value=self.value) + if self.dim == 4: + x = self.conv2d(x) + x = self.relu(x) - def forward(self, x: torch.Tensor): - x = F.pad(x, pad=self.pad, mode="constant", value=self.value) - return x + x = F.pad(x, pad=self.pad, mode="constant", value=self.value) + if self.dim == 4: + x = self.conv2d_1(x) + x = self.sigmoid(x) + return x + def _test_constant_pad_nd_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), + ) + .export() + .check_count({"torch.ops.aten.pad.default": 2}) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) -@common.parametrize( - "test_data", - test_data_suite, -) -def test_constant_pad_nd_tosa_MI(test_data: Tuple): - test_data, padding, value = test_data - pipeline = TosaPipelineMI[input_t1]( - ConstantPadND(padding, value), - (test_data,), - aten_op, - exir_op, - ) - pipeline.run() + def _test_constant_pad_nd_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.pad.default": 2}) + .to_edge() + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + @parameterized.expand(test_data_suite) + def test_constant_pad_nd_tosa_MI( + self, + test_name: str, + test_data: torch.Tensor, + padding: Tuple, + value: float | None = None, + ): + self._test_constant_pad_nd_tosa_MI_pipeline( + self.ConstantPadND(padding, value), (test_data,) + ) -@common.parametrize("test_data", test_data_suite) -def test_constant_pad_nd_tosa_BI(test_data: Tuple): - test_data, padding, value = test_data - pipeline = TosaPipelineBI[input_t1]( - ConstantPadND(padding, value), - (test_data,), - aten_op, - exir_op, - ) - pipeline.run() + @parameterized.expand(test_data_suite) + def test_constant_pad_nd_tosa_BI( + self, + test_name: str, + test_data: torch.Tensor, + padding: Tuple, + value: float | None = None, + ): + self._test_constant_pad_nd_tosa_BI_pipeline( + self.ConstantPadND(padding, value), (test_data,) + ) diff --git a/backends/arm/test/ops/test_conv_constant_pad_nd.py b/backends/arm/test/ops/test_conv_constant_pad_nd.py deleted file mode 100644 index 026b5d1dc4d..00000000000 --- a/backends/arm/test/ops/test_conv_constant_pad_nd.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# -# Test the pad_constant_nd op which pads the input tensor at specific dimension(s). -# - -from typing import Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.test_pipeline import ( - TosaPipelineBI, - TosaPipelineMI, -) - -aten_op = "torch.ops.aten.pad.default" -exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default" - -input_t1 = Tuple[torch.Tensor] # Input x - -test_data_suite = { - "4dim_last1dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1), - "4dim_last2dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2), - "4dim_last3dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3), - "4dim_last4dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4), - "3dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1), - "3dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2), - "3dim_last3dim": (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3), - "2dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0), 1), - "2dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1), 2), -} - - -"""Tests conv + pad.""" - - -class ConstantPadND(torch.nn.Module): - def __init__(self, pad: Tuple, value: float | None = None): - super().__init__() - self.dim = len(pad) // 2 - self.value = value - in_channels = 1 - # Only apply conv2d when the input dim = 4. - if self.dim == 4: - in_channels += pad[-3] + pad[-4] - - self.conv2d = nn.Conv2d( - in_channels=in_channels, - out_channels=3, - kernel_size=3, - bias=True, - stride=(2, 2), - padding=0, - ) - - in_channels = 3 - in_channels += pad[-3] + pad[-4] - self.conv2d_1 = nn.Conv2d( - in_channels=in_channels, - out_channels=3, - kernel_size=3, - bias=True, - padding="same", - ) - - nonzero_idx = len(pad) - for i in range(0, len(pad), 2): - if pad[i] + pad[i + 1] == 0: - nonzero_idx = i - break - self.pad = pad[:nonzero_idx] - self.relu = nn.ReLU() - self.sigmoid = nn.Sigmoid() - - def forward(self, x: torch.Tensor): - x = F.pad(x, pad=self.pad, mode="constant", value=self.value) - if self.dim == 4: - x = self.conv2d(x) - x = self.relu(x) - - x = F.pad(x, pad=self.pad, mode="constant", value=self.value) - if self.dim == 4: - x = self.conv2d_1(x) - x = self.sigmoid(x) - return x - - -@common.parametrize("test_data", test_data_suite) -def test_constant_pad_nd_tosa_MI(test_data: Tuple): - test_data, padding, value = test_data - pipeline = TosaPipelineMI[input_t1]( - ConstantPadND(padding, value), - (test_data,), - aten_op, - exir_op, - ) - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -def test_constant_pad_nd_tosa_BI(test_data: Tuple): - test_data, padding, value = test_data - pipeline = TosaPipelineBI[input_t1]( - ConstantPadND(padding, value), - (test_data,), - aten_op, - exir_op, - ) - pipeline.run() diff --git a/backends/arm/test/ops/test_cos.py b/backends/arm/test/ops/test_cos.py deleted file mode 100644 index 21902f51192..00000000000 --- a/backends/arm/test/ops/test_cos.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Tuple - -import torch - -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.test_pipeline import ( - EthosU55PipelineBI, - EthosU85PipelineBI, - TosaPipelineBI, - TosaPipelineMI, -) - -aten_op = "torch.ops.aten.cos.default" -input_t1 = Tuple[torch.Tensor] # Input x - -test_data_suite = { - # (test_name, test_data) - "zeros": torch.zeros(10, 10, 10, 10), - "ones": torch.ones(10, 10, 10), - "rand": torch.rand(10, 10) - 0.5, - "randn_pos": torch.randn(10) + 10, - "randn_neg": torch.randn(10) - 10, - "ramp": torch.arange(-16, 16, 0.2), -} - - -class Cos(torch.nn.Module): - - def forward(self, x: torch.Tensor): - return torch.cos(x) - - -@common.parametrize("test_data", test_data_suite) -def test_cos_tosa_MI(test_data: Tuple): - pipeline = TosaPipelineMI[input_t1]( - Cos(), - (test_data,), - aten_op, - exir_op=[], - ) - if conftest.get_option("tosa_version") == "1.0": - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -def test_cos_tosa_BI(test_data: Tuple): - pipeline = TosaPipelineBI[input_t1]( - Cos(), - (test_data,), - aten_op, - exir_op=[], - ) - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -def test_cos_tosa_u55_BI(test_data: Tuple): - pipeline = EthosU55PipelineBI[input_t1]( - Cos(), - (test_data,), - aten_op, - exir_ops=[], - run_on_fvp=False, - ) - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -def test_cos_tosa_u85_BI(test_data: Tuple): - pipeline = EthosU85PipelineBI[input_t1]( - Cos(), - (test_data,), - aten_op, - exir_ops=[], - run_on_fvp=False, - ) - pipeline.run() diff --git a/backends/arm/test/ops/test_sin.py b/backends/arm/test/ops/test_sin.py deleted file mode 100644 index 7f1f9f569af..00000000000 --- a/backends/arm/test/ops/test_sin.py +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Tuple - -import torch - -from executorch.backends.arm.test import common, conftest -from executorch.backends.arm.test.tester.test_pipeline import ( - EthosU55PipelineBI, - EthosU85PipelineBI, - TosaPipelineBI, - TosaPipelineMI, -) - -aten_op = "torch.ops.aten.sin.default" -input_t1 = Tuple[torch.Tensor] # Input x - -test_data_suite = { - # (test_name, test_data) - "zeros": torch.zeros(10, 10, 10, 10), - "ones": torch.ones(10, 10, 10), - "rand": torch.rand(10, 10) - 0.5, - "randn_pos": torch.randn(10) + 10, - "randn_neg": torch.randn(10) - 10, - "ramp": torch.arange(-16, 16, 0.2), -} - - -class Sin(torch.nn.Module): - - def forward(self, x: torch.Tensor): - return torch.sin(x) - - -@common.parametrize("test_data", test_data_suite) -def test_sin_tosa_MI(test_data: Tuple): - pipeline = TosaPipelineMI[input_t1]( - Sin(), - (test_data,), - aten_op, - exir_op=[], - ) - if conftest.get_option("tosa_version") == "1.0": - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -def test_sin_tosa_BI(test_data: Tuple): - pipeline = TosaPipelineBI[input_t1]( - Sin(), - (test_data,), - aten_op, - exir_op=[], - ) - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -def test_sin_tosa_u55_BI(test_data: Tuple): - pipeline = EthosU55PipelineBI[input_t1]( - Sin(), - (test_data,), - aten_op, - exir_ops=[], - run_on_fvp=False, - ) - pipeline.run() - - -@common.parametrize("test_data", test_data_suite) -def test_sin_tosa_u85_BI(test_data: Tuple): - pipeline = EthosU85PipelineBI[input_t1]( - Sin(), - (test_data,), - aten_op, - exir_ops=[], - run_on_fvp=False, - ) - pipeline.run()