diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index a5f66829da5..6bb1ce7dce1 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -48,6 +48,8 @@ class TableOps: exir_ops.edge.aten.reciprocal.default: torch.reciprocal, exir_ops.edge.aten.rsqrt.default: torch.rsqrt, exir_ops.edge.aten.sigmoid.default: torch.sigmoid, + exir_ops.edge.aten.cos.default: torch.cos, + exir_ops.edge.aten.sin.default: torch.sin, exir_ops.edge.aten.tanh.default: torch.tanh, exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid, exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish, diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index bd54c3e1f85..42ab9bca3cd 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -12,6 +12,7 @@ pool_2d_support, reduce_sum_support, right_shift_support, + sin_cos_support, slice_copy_support, to_copy_support, tosa_supported_operators, diff --git a/backends/arm/operator_support/sin_cos_support.py b/backends/arm/operator_support/sin_cos_support.py new file mode 100644 index 00000000000..9dd63e8258d --- /dev/null +++ b/backends/arm/operator_support/sin_cos_support.py @@ -0,0 +1,32 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +import torch.fx as fx +from executorch.backends.arm.operator_support.tosa_supported_operators import ( + register_tosa_support_check, + SupportedTOSAOperatorCheck, +) +from executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_tosa_support_check +class SinCosSupported(SupportedTOSAOperatorCheck): + targets = [ + exir_ops.edge.aten.cos.default, + exir_ops.edge.aten.sin.default, + ] + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 5de90bda252..d8d1711e725 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -23,7 +23,11 @@ EthosU55NotSupported, EthosU55TransposeCheck, ) -from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.backends.arm.tosa_specification import ( + Tosa_0_80, + Tosa_1_00, + TosaSpecification, +) from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops @@ -124,7 +128,9 @@ def tosa_support_factory( if not tosa_spec.support_float(): negative_checks.append(NeedsDecompositionCheck(reporter)) negative_checks.append(CheckProperQuantization(reporter)) - if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: + if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or ( + isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions + ): negative_checks.append(EthosU55NotSupported(reporter)) negative_checks.append(EthosU55DtypeSupport(reporter)) negative_checks.append(EthosU55TransposeCheck(reporter)) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 3ee243779e6..e496fe74d54 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -18,6 +18,7 @@ op_clamp, op_constant_pad_nd, op_conv2d, + op_cos, op_eq, op_erf, op_exp, @@ -38,6 +39,7 @@ op_rshift_tensor, op_rsqrt, op_sigmoid, + op_sin, op_slice, op_sub, op_sum, diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py index b65ebb2ac5d..efb5b0b72b0 100644 --- a/backends/arm/operators/op_any.py +++ b/backends/arm/operators/op_any.py @@ -4,9 +4,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import cast, List +from typing import Any, cast, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( # type: ignore NodeVisitor, register_node_visitor, @@ -17,16 +16,19 @@ @register_node_visitor -class AnyVisitor(NodeVisitor): +class AnyVisitor_0_80(NodeVisitor): target = "aten.any.dim" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore if not (inputs[0].dtype == output.dtype): raise ValueError( @@ -50,3 +52,42 @@ def define_node( tosa_graph.addOperator( ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr ) + + +@register_node_visitor +class AnyVisitor(NodeVisitor): + target = "aten.any.dim" + + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + + if not (inputs[0].dtype == output.dtype): + raise ValueError( + "All inputs and outputs need same dtype." + f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}." + ) + if not (inputs[0].dtype == ts.DType.BOOL): + raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}") + + input_shape = list(inputs[0].shape) + dim = cast(int, inputs[1].number) % len( + input_shape + ) # process the negative index + keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) + if not keep_dim: + raise ValueError("This case should be handled by ConvertAnyDimDimsPass") + + attr = ts.TosaSerializerAttribute() + attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim)) + + tosa_graph.addOperator( + ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr + ) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index bdd3425fda5..73a6713633a 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -4,12 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -36,14 +34,16 @@ def __init__(self, *args): def _build_generic_avgpool2d( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, input_zp: int, output_zp: int, - accumulator_type: ts.DType, + accumulator_type: Any, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + input_tensor = inputs[0] kernel_size_list = inputs[1].special stride_size_list = inputs[2].special @@ -79,10 +79,12 @@ def _build_generic_avgpool2d( def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + input_tensor = inputs[0] assert input_tensor.dtype == ts.DType.INT8 @@ -110,10 +112,135 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI): def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + assert ( + inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 + ), "Only FP32 and INT8 supported" + + if inputs[0].dtype == ts.DType.INT8: + super().define_node(node, tosa_graph, inputs, output) + + if inputs[0].dtype == ts.DType.FP32: + accumulator_type = ts.DType.FP32 + # Initilize zero point to zero. + input_zp = 0 + output_zp = 0 + + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + ) + + +@register_node_visitor +class AvgPool2dVisitor(NodeVisitor): + target = "aten.avg_pool2d.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def _build_generic_avgpool2d( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + input_zp: int, + output_zp: int, + accumulator_type: Any, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + input_tensor = inputs[0] + kernel_size_list = inputs[1].special + stride_size_list = inputs[2].special + + try: + pad_size_list = inputs[3].special + pad_size_list = [ + pad_size_list[0], + pad_size_list[0], + pad_size_list[1], + pad_size_list[1], + ] + except IndexError: + pad_size_list = [0, 0, 0, 0] + + attr = ts.TosaSerializerAttribute() + attr.AvgPool2dAttribute( + kernel=kernel_size_list, + stride=stride_size_list, + pad=pad_size_list, + acc_type=accumulator_type, + ) + input_zp_tensor = tosa_graph.addConst( + shape=[1], dtype=output.dtype, vals=[input_zp] + ) + output_zp_tensor = tosa_graph.addConst( + shape=[1], dtype=output.dtype, vals=[output_zp] + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().AVG_POOL2D, + [input_tensor.name, input_zp_tensor.name, output_zp_tensor.name], + [output.name], + attr, + ) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore + + input_tensor = inputs[0] + assert input_tensor.dtype == ts.DType.INT8 + + accumulator_type = ts.DType.INT32 + + input_qargs = get_input_qparams(node) + input_zp = input_qargs[0].zp + + output_qargs = get_output_qparams(node) + output_zp = output_qargs[0].zp + + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + ) + + +@register_node_visitor +class AvgPool2dVisitor_FP(AvgPool2dVisitor): + target = "aten.avg_pool2d.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore + assert ( inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 ), "Only FP32 and INT8 supported" diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 6b1710301b1..bb77ba77940 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -5,9 +5,8 @@ # pyre-unsafe -from typing import List +from typing import Any, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -17,19 +16,22 @@ @register_node_visitor -class CatVisitor(NodeVisitor): +class CatVisitor_0_80(NodeVisitor): target = "aten.cat.default" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def __init__(self, *args): super().__init__(*args) def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore tensors = inputs[0].special dim = 0 if len(inputs) < 2 else inputs[1].number @@ -46,3 +48,38 @@ def define_node( [output.name], attr, ) + + +@register_node_visitor +class CatVisitor(NodeVisitor): + target = "aten.cat.default" + + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + + tensors = inputs[0].special + dim = 0 if len(inputs) < 2 else inputs[1].number + rank = len(output.shape) + dim = (dim + rank) % rank + dim = output.dim_order.index(dim) + + attr = ts.TosaSerializerAttribute() + attr.ConcatAttribute(dim) + + tosa_graph.addOperator( + ts.TosaOp.Op().CONCAT, + [tensor.name for tensor in tensors], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index b2c31df96ab..75cdc0b0fc4 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -5,12 +5,10 @@ # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, ) @@ -19,20 +17,27 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification @register_node_visitor -class ConstantPadNDVisitor(NodeVisitor): +class ConstantPadNDVisitor_0_80(NodeVisitor): target = "aten.constant_pad_nd.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) @@ -74,3 +79,72 @@ def define_node( tosa_graph.addOperator( ts.TosaOp.Op().PAD, [inputs[0].name], [output.name], attr ) + + +@register_node_visitor +class ConstantPadNDVisitor(NodeVisitor): + + target = "aten.constant_pad_nd.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + qargs = input_qparams[0] + pad_const_val = qargs.quantize_value(inputs[2].number).item() + pad_const_dtype = ts.DType.INT8 + else: + pad_const_val = inputs[2].number + pad_const_dtype = inputs[0].dtype + + rank = len(output.shape) + # Each dim needs 2 padding values. For example, to pad the last dimension, the pad has the form + # (padding_left, padding_right); to pad the last two dimensions, the pad has the form + # (padding_left, padding_right, padding_top, padding_bottom), and so on. For PyTorch NCHW format, the padding + # values are in the reverse order. So, firstly we need to reverse the input padding parameters. + input_pad = sum( + [ + [inputs[1].special[i], inputs[1].special[i + 1]] + for i in range(0, len(inputs[1].special), 2) + ][::-1], + [], + ) + # Then, add dummy zeros to make sure that both input_pad and output_pad has the same size. + input_pad = [0] * (rank * 2 - len(inputs[1].special)) + input_pad + # For PyTorch NCHW format, dim order is [0,...,rank-1] + input_dim_order = list(range(rank)) + output_pad = [0] * rank * 2 + + # Map input padding parameters into output padding parameters. TOSA is NHWC format. + for input_dim_idx, input_dim in enumerate(input_dim_order): + output_dim_idx = output.dim_order.index(input_dim) + output_pad[output_dim_idx * 2 : (output_dim_idx + 1) * 2] = input_pad[ + input_dim_idx * 2 : (input_dim_idx + 1) * 2 + ] + + padding = tosa_graph.addConst( + shape=[len(output_pad)], dtype=ts.DType.SHAPE, vals=output_pad + ) + + pad_const = tosa_graph.addConst( + shape=[1], dtype=pad_const_dtype, vals=[pad_const_val] + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().PAD, + [inputs[0].name, padding.name, pad_const.name], + [output.name], + ) diff --git a/backends/arm/operators/op_cos.py b/backends/arm/operators/op_cos.py new file mode 100644 index 00000000000..1fee25511ce --- /dev/null +++ b/backends/arm/operators/op_cos.py @@ -0,0 +1,46 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts # type: ignore +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification +from torch.fx import Node + + +@register_node_visitor +class CosVisitor(NodeVisitor): + target = "aten.cos.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got input_dtype: " + f"{inputs[0].dtype} and output_dtype: {output.dtype}" + ) + + tosa_graph.addOperator(ts.TosaOp.Op().COS, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py index 01243716129..e174069ee77 100644 --- a/backends/arm/operators/op_erf.py +++ b/backends/arm/operators/op_erf.py @@ -3,11 +3,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch.fx -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -29,10 +27,42 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if not (inputs[0].dtype == output.dtype): + raise ValueError( + "All inputs and output need same dtype." + f"Got {inputs[0].dtype=}, {output.dtype=}" + ) + if not (inputs[0].dtype == ts.DType.FP32): + raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}") + # MI lowering + tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name]) + + +@register_node_visitor +class ERFVisitor(NodeVisitor): + target = "aten.erf.default" + + # INT case handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + if not (inputs[0].dtype == output.dtype): raise ValueError( "All inputs and output need same dtype." @@ -40,5 +70,6 @@ def define_node( ) if not (inputs[0].dtype == ts.DType.FP32): raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}") + # MI lowering - tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name]) + tosa_graph.addOperator(ts.TosaOp.Op().ERF, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index ca067b3b8be..60cc727d149 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -4,9 +4,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -29,10 +28,43 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got input dtype: " + f"{inputs[0].dtype} and output dtype: {output.dtype}" + ) + + tosa_graph.addOperator(ts.TosaOp.Op().EXP, [inputs[0].name], [output.name]) + + +@register_node_visitor +class ExpVisitor(NodeVisitor): + target = "aten.exp.default" + + # BI case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts if len(node.all_input_nodes) != 1: raise ValueError( diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 9942cbf4702..b08bbcec003 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -4,9 +4,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -17,7 +16,7 @@ @register_node_visitor -class LogVisitor(NodeVisitor): +class LogVisitor_0_80_MI(NodeVisitor): target = "aten.log.default" # BI case should be handled by op_table @@ -29,10 +28,44 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got input_dtype: " + f"{inputs[0].dtype} and output_dtype: {output.dtype}" + ) + + tosa_graph.addOperator(ts.TosaOp.Op().LOG, [inputs[0].name], [output.name]) + + +@register_node_visitor +class LogVisitor(NodeVisitor): + target = "aten.log.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import serializer.tosa_serializer as ts + if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index fcf2636977d..928262aefc5 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -4,12 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore - from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( get_input_qparams, get_output_qparams, @@ -19,22 +17,29 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification @register_node_visitor -class MaxPool2dVisitor(NodeVisitor): +class MaxPool2dVisitor_0_80(NodeVisitor): target = "aten.max_pool2d.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore input_tensor = inputs[0] kernel_size = inputs[1].special @@ -80,3 +85,53 @@ def define_node( [output.name], attr, ) + + +@register_node_visitor +class MaxPool2dVisitor(NodeVisitor): + target = "aten.max_pool2d.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + input_tensor = inputs[0] + kernel_size = inputs[1].special + stride = inputs[2].special + + try: + pad_size_list = inputs[3].special + pad_size_list = [ + pad_size_list[0], + pad_size_list[0], + pad_size_list[1], + pad_size_list[1], + ] + except IndexError: + pad_size_list = [0, 0, 0, 0] + + attr = ts.TosaSerializerAttribute() + attr.MaxPool2dAttribute( + kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1 + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().MAX_POOL2D, + [input_tensor.name], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index c92a008a281..f3ea8b00961 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -5,11 +5,10 @@ # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -88,20 +87,61 @@ def transform_permutation_vector(permutation_vector: list[int], dim_order: list[ return permutation_vector +@register_node_visitor +class PermuteVisitor_0_80(NodeVisitor): + target = "aten.permute_copy.default" + + tosa_specs = NodeVisitor.tosa_specs_0_80 + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + # The permutation vector describes a permutation P in default Pytorch dim_order. + # For rank 4, the default dim_order NCHW. + # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) + permutation_vector = inputs[1].special + + if output.dim_order != tuple(range(len(output.dim_order))): + # the permutation vector can't be used directly if we are not in NCHW dim_order. + # Transform to dim_order. + permutation_vector = transform_permutation_vector( + permutation_vector, output.dim_order + ) + + attr = ts.TosaSerializerAttribute() + attr.TransposeAttribute(permutation_vector) + tosa_graph.addOperator( + ts.TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr + ) + + @register_node_visitor class PermuteVisitor(NodeVisitor): target = "aten.permute_copy.default" + tosa_specs = NodeVisitor.tosa_specs_1_00 + def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import serializer.tosa_serializer as ts + # The permutation vector describes a permutation P in default Pytorch dim_order. # For rank 4, the default dim_order NCHW. # E.g. (2,3,0,1) -> permute (n,c,h,w) to (w,c,n,h) diff --git a/backends/arm/operators/op_pow.py b/backends/arm/operators/op_pow.py index d3b92feff12..781fce3c79f 100644 --- a/backends/arm/operators/op_pow.py +++ b/backends/arm/operators/op_pow.py @@ -5,9 +5,8 @@ # pyre-unsafe -from typing import List +from typing import Any, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -31,10 +30,53 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): + raise ValueError( + "All inputs and outputs need same dtype." + f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}" + ) + if inputs[0].dtype not in [ts.DType.FP32, ts.DType.FP16]: + raise ValueError( + f"All inputs need to be FP32 or FP16. Got {inputs[0].dtype}" + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().POW, + [ + inputs[0].name, + inputs[1].name, + ], + [output.name], + None, + ) + + +@register_node_visitor +class PowVisitor(NodeVisitor): + target = "aten.pow.Tensor_Tensor" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( "All inputs and outputs need same dtype." diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index c75fb99977e..7d1ee951993 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -4,11 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -30,10 +29,46 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got " + f"{inputs[0].dtype=} and {output.dtype=}" + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().RECIPROCAL, [inputs[0].name], [output.name] + ) + + +@register_node_visitor +class ReciprocalVisitor(NodeVisitor): + target = "aten.reciprocal.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index 125f5493a29..375dd76ba8d 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -5,30 +5,32 @@ # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import Tosa_0_80 +from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00 @register_node_visitor -class RshiftVisitor(NodeVisitor): +class RshiftVisitor_0_80(NodeVisitor): target = "aten.bitwise_right_shift.Tensor" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore attr = ts.TosaSerializerAttribute() round = False @@ -44,3 +46,34 @@ def define_node( [output.name], attr, ) + + +@register_node_visitor +class RshiftVisitor(NodeVisitor): + target = "aten.bitwise_right_shift.Tensor" + + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + + attr = ts.TosaSerializerAttribute() + round = False + if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions: + # U55 only supports INT32 and round == True + # TODO MLETORCH-525 Emulate round == False with different decomposition + round = True + attr.ArithmeticRightShiftAttribute(round=round) + + tosa_graph.addOperator( + ts.TosaOp.Op().ARITHMETIC_RIGHT_SHIFT, + [inputs[0].name, inputs[1].name], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index e3937f8c44a..784c4b4d257 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -4,11 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -30,10 +29,44 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got " + f"{inputs[0].dtype=} and {output.dtype=}" + ) + + tosa_graph.addOperator(ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name]) + + +@register_node_visitor +class RsqrtVisitor(NodeVisitor): + target = "aten.rsqrt.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 9a002036fee..a43e9ae798f 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -4,9 +4,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -29,10 +28,43 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got input_dtype: " + f"{inputs[0].dtype} and output_dtype: {output.dtype}" + ) + + tosa_graph.addOperator(ts.TosaOp.Op().SIGMOID, [inputs[0].name], [output.name]) + + +@register_node_visitor +class SigmoidVisitor(NodeVisitor): + target = "aten.sigmoid.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts if len(node.all_input_nodes) != 1: raise ValueError( diff --git a/backends/arm/operators/op_sin.py b/backends/arm/operators/op_sin.py new file mode 100644 index 00000000000..ee444c38f37 --- /dev/null +++ b/backends/arm/operators/op_sin.py @@ -0,0 +1,46 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts # type: ignore +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification +from torch.fx import Node + + +@register_node_visitor +class SinVisitor(NodeVisitor): + target = "aten.sin.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got input_dtype: " + f"{inputs[0].dtype} and output_dtype: {output.dtype}" + ) + + tosa_graph.addOperator(ts.TosaOp.Op().SIN, [inputs[0].name], [output.name]) diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 51cf1ee786b..01af36c4d37 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -4,9 +4,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -17,7 +16,7 @@ @register_node_visitor -class TanhVisitor_080_MI(NodeVisitor): +class TanhVisitor_0_80_MI(NodeVisitor): target = "aten.tanh.default" # BI case should be handled by op_table @@ -29,10 +28,44 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got input_dtype: " + f"{inputs[0].dtype} and output_dtype: {output.dtype}" + ) + + tosa_graph.addOperator(ts.TosaOp.Op().TANH, [inputs[0].name], [output.name]) + + +@register_node_visitor +class TanhVisitor(NodeVisitor): + target = "aten.tanh.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + if len(node.all_input_nodes) != 1: raise ValueError( f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_upsample_nearest2d.py index 23d24b78339..63694b715f0 100644 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ b/backends/arm/operators/op_upsample_nearest2d.py @@ -4,11 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -20,19 +19,23 @@ @register_node_visitor -class UpsampleNearest2dVisitor(NodeVisitor): +class UpsampleNearest2dVisitor_0_80(NodeVisitor): target = "aten.upsample_nearest2d.vec" + tosa_specs = NodeVisitor.tosa_specs_0_80 + def __init__(self, *args): super().__init__(*args) def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + assert ( inputs[0].shape is not None and output.shape is not None ), "Only static shapes are supported" @@ -67,3 +70,74 @@ def in_int16_range(x): tosa_graph.addOperator( ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr ) + + +@register_node_visitor +class UpsampleNearest2dVisitor(NodeVisitor): + target = "aten.upsample_nearest2d.vec" + + tosa_specs = NodeVisitor.tosa_specs_1_00 + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts + + assert ( + inputs[0].shape is not None and output.shape is not None + ), "Only static shapes are supported" + + # tosa_shape output is NHWC, take HW + input_size_yx = torch.tensor( + tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3] + ) + # Ignore scale and size parameters, directly use the output size as + # we only support static shapes currently + output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3]) + + scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( + input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True + ) + + def in_int16_range(x): + return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) + + assert in_int16_range(scale_n_yx) + assert in_int16_range(scale_d_yx) + assert in_int16_range(border_yx) + + scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]] + scales_tensor = tosa_graph.addConst( + [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" + ) + offset = offset_yx.tolist() + offset_tensor = tosa_graph.addConst( + [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" + ) + border = border_yx.tolist() + border_tensor = tosa_graph.addConst( + [len(border)], ts.DType.SHAPE, border, node.name + "_border" + ) + attr = ts.TosaSerializerAttribute() + attr.ResizeAttribute( + mode=ResizeMode.NEAREST, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().RESIZE, + [ + inputs[0].name, + scales_tensor.name, + offset_tensor.name, + border_tensor.name, + ], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index ba2469e74e1..b58fda1c399 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -3,10 +3,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List, Sequence - -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore -import tosa_tools.v0_80.tosa.Op as TosaOp # type: ignore +from typing import Any, List, Sequence from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, @@ -17,54 +14,148 @@ from torch.fx import Node -def _add_node_to_tosa_graph( - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - supported_dtypes: Sequence, -) -> None: - if len(inputs) != 3: - raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") - - if inputs[0].dtype is not ts.DType.BOOL: - raise ValueError("Input 0 needs to have dtype BOOL") - if inputs[1].dtype != inputs[2].dtype: - raise ValueError( - "Non-condition tensors must have same data type, got " - f"{inputs[1].dtype} and {inputs[2].dtype}" - ) - for input_ in inputs[1:]: - if input_.dtype not in supported_dtypes: +@register_node_visitor +class WhereVisitor_0_80_BI(NodeVisitor): + target = "aten.where.self" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+BI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def _add_node_to_tosa_graph( + self, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + supported_dtypes: Sequence, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + if len(inputs) != 3: + raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") + + if inputs[0].dtype is not ts.DType.BOOL: + raise ValueError("Input 0 needs to have dtype BOOL") + if inputs[1].dtype != inputs[2].dtype: raise ValueError( - f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}" + "Non-condition tensors must have same data type, got " + f"{inputs[1].dtype} and {inputs[2].dtype}" ) + for input_ in inputs[1:]: + if input_.dtype not in supported_dtypes: + raise ValueError( + f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}" + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().SELECT, + [inputs[0].name, inputs[1].name, inputs[2].name], + [output.name], + None, + ) - tosa_graph.addOperator( - TosaOp.Op().SELECT, - [inputs[0].name, inputs[1].name, inputs[2].name], - [output.name], - None, - ) + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + bi_supported_dtypes = [ + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.BOOL, + ] + self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes) @register_node_visitor -class WhereVisitor_080_BI(NodeVisitor): +class WhereVisitor_0_80_MI(WhereVisitor_0_80_BI): + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80+MI"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + + mi_supported_dtypes = [ + ts.DType.FP16, + ts.DType.FP32, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.BOOL, + ] + self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes) + + +@register_node_visitor +class WhereVisitor_INT(NodeVisitor): target = "aten.where.self" tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+BI"), + TosaSpecification.create_from_string("TOSA-1.0+INT"), ] def __init__(self, *args): super().__init__(*args) + def _add_node_to_tosa_graph( + self, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + supported_dtypes: Sequence, + ) -> None: + import serializer.tosa_serializer as ts + + if len(inputs) != 3: + raise ValueError(f"aten.where.self expects 3 arguments, got {len(inputs)}") + + if inputs[0].dtype is not ts.DType.BOOL: + raise ValueError("Input 0 needs to have dtype BOOL") + if inputs[1].dtype != inputs[2].dtype: + raise ValueError( + "Non-condition tensors must have same data type, got " + f"{inputs[1].dtype} and {inputs[2].dtype}" + ) + for input_ in inputs[1:]: + if input_.dtype not in supported_dtypes: + raise ValueError( + f"Input needs to be of torch dtype {supported_dtypes}, got {input_.dtype}" + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().SELECT, + [inputs[0].name, inputs[1].name, inputs[2].name], + [output.name], + None, + ) + def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import serializer.tosa_serializer as ts bi_supported_dtypes = [ ts.DType.INT8, @@ -72,14 +163,14 @@ def define_node( ts.DType.INT32, ts.DType.BOOL, ] - _add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes) + self._add_node_to_tosa_graph(tosa_graph, inputs, output, bi_supported_dtypes) @register_node_visitor -class WhereVisitor_080_MI(WhereVisitor_080_BI): +class WhereVisitor_FP(WhereVisitor_INT): tosa_specs = [ - TosaSpecification.create_from_string("TOSA-0.80+MI"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] def __init__(self, *args): @@ -88,10 +179,12 @@ def __init__(self, *args): def define_node( self, node: Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import serializer.tosa_serializer as ts + mi_supported_dtypes = [ ts.DType.FP16, ts.DType.FP32, @@ -100,4 +193,4 @@ def define_node( ts.DType.INT32, ts.DType.BOOL, ] - _add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes) + self._add_node_to_tosa_graph(tosa_graph, inputs, output, mi_supported_dtypes) diff --git a/backends/arm/operators/ops_binary.py b/backends/arm/operators/ops_binary.py index a17da41f767..425007bab3c 100644 --- a/backends/arm/operators/ops_binary.py +++ b/backends/arm/operators/ops_binary.py @@ -5,13 +5,11 @@ # pyre-unsafe -from typing import List +from typing import Any, List import torch import torch.fx -import tosa_tools.v0_80.serializer.tosa_serializer as ts - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -19,19 +17,50 @@ from executorch.backends.arm.tosa_mapping import TosaArg +def binary_operator_factory_0_80(bw_target: str, tosa_op): + """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op.""" + + class BinaryOperator_0_80(NodeVisitor): + target = bw_target + tosa_specs = NodeVisitor.tosa_specs_0_80 + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401 + + if not (inputs[0].dtype == inputs[1].dtype == output.dtype): + raise ValueError( + "All inputs and outputs need same dtype." + f"Got {inputs[0].dtype=}, {inputs[1].dtype=}, {output.dtype=}." + ) + + tosa_graph.addOperator( + tosa_op, [inputs[0].name, inputs[1].name], [output.name] + ) + + register_node_visitor(BinaryOperator_0_80) + + def binary_operator_factory(bw_target: str, tosa_op): """Creates and registers NodeVisitors for operators that have two inputs and map directly to a TOSA op.""" class BinaryOperator(NodeVisitor): target = bw_target + tosa_specs = NodeVisitor.tosa_specs_1_00 def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import serializer.tosa_serializer as ts # type: ignore # noqa: F401 if not (inputs[0].dtype == inputs[1].dtype == output.dtype): raise ValueError( @@ -46,6 +75,20 @@ def define_node( register_node_visitor(BinaryOperator) +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + +binary_operator_factory_0_80("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND) +binary_operator_factory_0_80("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR) +binary_operator_factory_0_80("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR) +binary_operator_factory_0_80("aten.logical_and.default", ts.TosaOp.Op().LOGICAL_AND) +binary_operator_factory_0_80("aten.logical_xor.default", ts.TosaOp.Op().LOGICAL_XOR) +binary_operator_factory_0_80("aten.logical_or.default", ts.TosaOp.Op().LOGICAL_OR) +binary_operator_factory_0_80( + "aten.bitwise_left_shift.Tensor", ts.TosaOp.Op().LOGICAL_LEFT_SHIFT +) + +import serializer.tosa_serializer as ts # type: ignore + binary_operator_factory("aten.bitwise_and.Tensor", ts.TosaOp.Op().BITWISE_AND) binary_operator_factory("aten.bitwise_xor.Tensor", ts.TosaOp.Op().BITWISE_XOR) binary_operator_factory("aten.bitwise_or.Tensor", ts.TosaOp.Op().BITWISE_OR) diff --git a/backends/arm/operators/ops_unary.py b/backends/arm/operators/ops_unary.py index 3f713e086e6..3bb2be16585 100644 --- a/backends/arm/operators/ops_unary.py +++ b/backends/arm/operators/ops_unary.py @@ -4,11 +4,10 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import List +from typing import Any, List import torch.fx -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -17,6 +16,44 @@ from executorch.backends.arm.tosa_mapping import TosaArg +def unary_operator_factory_0_80(unary_target: str, tosa_op): + "Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op." + + # Some TOSA unary operators only support float + fp_only_ops = ["aten.floor.default"] + + class UnaryOperator_0_80(NodeVisitor): + target = unary_target + tosa_specs = NodeVisitor.tosa_specs_0_80 + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore # noqa: F401 + + if not (inputs[0].dtype == output.dtype): + raise ValueError( + "All inputs and output need same dtype." + f"Got {inputs[0].dtype=}, {output.dtype=}" + ) + + if self.target in fp_only_ops and not (inputs[0].dtype == ts.DType.FP32): + raise ValueError( + "All inputs need to be FP32." f"Got {inputs[0].dtype=}" + ) + + tosa_graph.addOperator(tosa_op, [inputs[0].name], [output.name]) + + register_node_visitor(UnaryOperator_0_80) + + def unary_operator_factory(unary_target: str, tosa_op): "Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op." @@ -25,6 +62,7 @@ def unary_operator_factory(unary_target: str, tosa_op): class UnaryOperator(NodeVisitor): target = unary_target + tosa_specs = NodeVisitor.tosa_specs_1_00 def __init__(self, *args): super().__init__(*args) @@ -32,10 +70,11 @@ def __init__(self, *args): def define_node( self, node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, + tosa_graph: Any, inputs: List[TosaArg], output: TosaArg, ) -> None: + import serializer.tosa_serializer as ts # type: ignore # noqa: F401 if not (inputs[0].dtype == output.dtype): raise ValueError( @@ -53,6 +92,14 @@ def define_node( register_node_visitor(UnaryOperator) +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + +unary_operator_factory_0_80("aten.ceil.default", ts.TosaOp.Op().CEIL) +unary_operator_factory_0_80("aten.floor.default", ts.TosaOp.Op().FLOOR) +unary_operator_factory_0_80("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT) + +import serializer.tosa_serializer as ts # type: ignore + unary_operator_factory("aten.ceil.default", ts.TosaOp.Op().CEIL) unary_operator_factory("aten.floor.default", ts.TosaOp.Op().FLOOR) unary_operator_factory("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index ad866fa9d13..8c081e7c00a 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -171,6 +171,8 @@ def _match_pattern( torch.ops.aten.reciprocal.default, torch.ops.aten.rsqrt.default, torch.ops.aten.sigmoid.default, + torch.ops.aten.cos.default, + torch.ops.aten.sin.default, torch.ops.aten.tanh.default, torch.ops.aten.sum.dim_IntList, torch.ops.aten.hardsigmoid.default, diff --git a/backends/arm/test/misc/test_multiple_delegates.py b/backends/arm/test/misc/test_multiple_delegates.py index 9103e8ca899..ab768d273c6 100644 --- a/backends/arm/test/misc/test_multiple_delegates.py +++ b/backends/arm/test/misc/test_multiple_delegates.py @@ -20,7 +20,7 @@ def get_inputs(self): def forward(self, x: torch.Tensor, y: torch.Tensor): z = x + y - s = torch.sin(z) + s = torch.tan(z) return s * z def test_tosa_MI(self): diff --git a/backends/arm/test/ops/test_constant_pad_nd.py b/backends/arm/test/ops/test_constant_pad_nd.py index 3dfd28640eb..9a19f6fbf5f 100644 --- a/backends/arm/test/ops/test_constant_pad_nd.py +++ b/backends/arm/test/ops/test_constant_pad_nd.py @@ -2,143 +2,74 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - # # Test the pad_constant_nd op which pads the input tensor at specific dimension(s). # -import unittest from typing import Tuple import torch -import torch.nn as nn import torch.nn.functional as F from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.arm_tester import ArmTester -from parameterized import parameterized - -test_data_suite = [ - ("4dim_last1dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1), - ("4dim_last2dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2), - ("4dim_last3dim", torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3), - ("4dim_last4dim", torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4), - ("3dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1), - ("3dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2), - ("3dim_last3dim", torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3), - ("2dim_last1dim", torch.rand(1, 1, 16), (1, 1, 0, 0), 1), - ("2dim_last2dim", torch.rand(1, 1, 16), (1, 0, 1, 1), 2), -] - - -class TestConstantPadND(unittest.TestCase): - """Tests pad.""" - - class ConstantPadND(torch.nn.Module): - def __init__(self, pad: Tuple, value: float | None = None): - super().__init__() - self.dim = len(pad) // 2 - self.value = value - in_channels = 1 - # Only apply conv2d when the input dim = 4. - if self.dim == 4: - in_channels += pad[-3] + pad[-4] - - self.conv2d = nn.Conv2d( - in_channels=in_channels, - out_channels=3, - kernel_size=3, - bias=True, - stride=(2, 2), - padding=0, - ) +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineBI, + TosaPipelineMI, +) - in_channels = 3 - in_channels += pad[-3] + pad[-4] - self.conv2d_1 = nn.Conv2d( - in_channels=in_channels, - out_channels=3, - kernel_size=3, - bias=True, - padding="same", - ) +aten_op = "torch.ops.aten.pad.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default" +input_t1 = Tuple[torch.Tensor] # Input x +test_data_suite = { + "4dim_last1dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1), + "4dim_last2dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2), + "4dim_last3dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3), + "4dim_last4dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4), + "3dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1), + "3dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2), + "3dim_last3dim": (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3), + "2dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0), 1), + "2dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1), 2), +} +"""Tests pad.""" - nonzero_idx = len(pad) - for i in range(0, len(pad), 2): - if pad[i] + pad[i + 1] == 0: - nonzero_idx = i - break - self.pad = pad[:nonzero_idx] - self.relu = nn.ReLU() - self.sigmoid = nn.Sigmoid() - def forward(self, x: torch.Tensor): - x = F.pad(x, pad=self.pad, mode="constant", value=self.value) - if self.dim == 4: - x = self.conv2d(x) - x = self.relu(x) +class ConstantPadND(torch.nn.Module): + def __init__(self, pad: Tuple, value: float | None = None): + super().__init__() + self.value = value + nonzero_idx = len(pad) + for i in range(0, len(pad), 2): + if pad[i] + pad[i + 1] == 0: + nonzero_idx = i + break + self.pad = pad[:nonzero_idx] - x = F.pad(x, pad=self.pad, mode="constant", value=self.value) - if self.dim == 4: - x = self.conv2d_1(x) - x = self.sigmoid(x) - return x + def forward(self, x: torch.Tensor): + x = F.pad(x, pad=self.pad, mode="constant", value=self.value) + return x - def _test_constant_pad_nd_tosa_MI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), - ) - .export() - .check_count({"torch.ops.aten.pad.default": 2}) - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data) - ) - def _test_constant_pad_nd_tosa_BI_pipeline( - self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] - ): - ( - ArmTester( - module, - example_inputs=test_data, - compile_spec=common.get_tosa_compile_spec("TOSA-0.80+BI"), - ) - .quantize() - .export() - .check_count({"torch.ops.aten.pad.default": 2}) - .to_edge() - .partition() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) - .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=1) - ) +@common.parametrize( + "test_data", + test_data_suite, +) +def test_constant_pad_nd_tosa_MI(test_data: Tuple): + test_data, padding, value = test_data + pipeline = TosaPipelineMI[input_t1]( + ConstantPadND(padding, value), + (test_data,), + aten_op, + exir_op, + ) + pipeline.run() - @parameterized.expand(test_data_suite) - def test_constant_pad_nd_tosa_MI( - self, - test_name: str, - test_data: torch.Tensor, - padding: Tuple, - value: float | None = None, - ): - self._test_constant_pad_nd_tosa_MI_pipeline( - self.ConstantPadND(padding, value), (test_data,) - ) - @parameterized.expand(test_data_suite) - def test_constant_pad_nd_tosa_BI( - self, - test_name: str, - test_data: torch.Tensor, - padding: Tuple, - value: float | None = None, - ): - self._test_constant_pad_nd_tosa_BI_pipeline( - self.ConstantPadND(padding, value), (test_data,) - ) +@common.parametrize("test_data", test_data_suite) +def test_constant_pad_nd_tosa_BI(test_data: Tuple): + test_data, padding, value = test_data + pipeline = TosaPipelineBI[input_t1]( + ConstantPadND(padding, value), + (test_data,), + aten_op, + exir_op, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_conv_constant_pad_nd.py b/backends/arm/test/ops/test_conv_constant_pad_nd.py new file mode 100644 index 00000000000..026b5d1dc4d --- /dev/null +++ b/backends/arm/test/ops/test_conv_constant_pad_nd.py @@ -0,0 +1,114 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Test the pad_constant_nd op which pads the input tensor at specific dimension(s). +# + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.pad.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_pad_default" + +input_t1 = Tuple[torch.Tensor] # Input x + +test_data_suite = { + "4dim_last1dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 0, 0, 0, 0, 0), 1), + "4dim_last2dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 0, 0, 0, 0, 0), 2), + "4dim_last3dim": (torch.rand(1, 1, 16, 16), (1, 1, 0, 2, 0, 2, 0, 0), 3), + "4dim_last4dim": (torch.rand(1, 1, 16, 16), (1, 0, 1, 1, 0, 2, 0, 2), 4), + "3dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0, 0, 0), 1), + "3dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1, 0, 0), 2), + "3dim_last3dim": (torch.rand(1, 1, 16), (1, 0, 1, 0, 1, 1), 3), + "2dim_last1dim": (torch.rand(1, 1, 16), (1, 1, 0, 0), 1), + "2dim_last2dim": (torch.rand(1, 1, 16), (1, 0, 1, 1), 2), +} + + +"""Tests conv + pad.""" + + +class ConstantPadND(torch.nn.Module): + def __init__(self, pad: Tuple, value: float | None = None): + super().__init__() + self.dim = len(pad) // 2 + self.value = value + in_channels = 1 + # Only apply conv2d when the input dim = 4. + if self.dim == 4: + in_channels += pad[-3] + pad[-4] + + self.conv2d = nn.Conv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + bias=True, + stride=(2, 2), + padding=0, + ) + + in_channels = 3 + in_channels += pad[-3] + pad[-4] + self.conv2d_1 = nn.Conv2d( + in_channels=in_channels, + out_channels=3, + kernel_size=3, + bias=True, + padding="same", + ) + + nonzero_idx = len(pad) + for i in range(0, len(pad), 2): + if pad[i] + pad[i + 1] == 0: + nonzero_idx = i + break + self.pad = pad[:nonzero_idx] + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, x: torch.Tensor): + x = F.pad(x, pad=self.pad, mode="constant", value=self.value) + if self.dim == 4: + x = self.conv2d(x) + x = self.relu(x) + + x = F.pad(x, pad=self.pad, mode="constant", value=self.value) + if self.dim == 4: + x = self.conv2d_1(x) + x = self.sigmoid(x) + return x + + +@common.parametrize("test_data", test_data_suite) +def test_constant_pad_nd_tosa_MI(test_data: Tuple): + test_data, padding, value = test_data + pipeline = TosaPipelineMI[input_t1]( + ConstantPadND(padding, value), + (test_data,), + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_constant_pad_nd_tosa_BI(test_data: Tuple): + test_data, padding, value = test_data + pipeline = TosaPipelineBI[input_t1]( + ConstantPadND(padding, value), + (test_data,), + aten_op, + exir_op, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_cos.py b/backends/arm/test/ops/test_cos.py new file mode 100644 index 00000000000..21902f51192 --- /dev/null +++ b/backends/arm/test/ops/test_cos.py @@ -0,0 +1,83 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.cos.default" +input_t1 = Tuple[torch.Tensor] # Input x + +test_data_suite = { + # (test_name, test_data) + "zeros": torch.zeros(10, 10, 10, 10), + "ones": torch.ones(10, 10, 10), + "rand": torch.rand(10, 10) - 0.5, + "randn_pos": torch.randn(10) + 10, + "randn_neg": torch.randn(10) - 10, + "ramp": torch.arange(-16, 16, 0.2), +} + + +class Cos(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return torch.cos(x) + + +@common.parametrize("test_data", test_data_suite) +def test_cos_tosa_MI(test_data: Tuple): + pipeline = TosaPipelineMI[input_t1]( + Cos(), + (test_data,), + aten_op, + exir_op=[], + ) + if conftest.get_option("tosa_version") == "1.0": + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_cos_tosa_BI(test_data: Tuple): + pipeline = TosaPipelineBI[input_t1]( + Cos(), + (test_data,), + aten_op, + exir_op=[], + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_cos_tosa_u55_BI(test_data: Tuple): + pipeline = EthosU55PipelineBI[input_t1]( + Cos(), + (test_data,), + aten_op, + exir_ops=[], + run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_cos_tosa_u85_BI(test_data: Tuple): + pipeline = EthosU85PipelineBI[input_t1]( + Cos(), + (test_data,), + aten_op, + exir_ops=[], + run_on_fvp=False, + ) + pipeline.run() diff --git a/backends/arm/test/ops/test_sin.py b/backends/arm/test/ops/test_sin.py new file mode 100644 index 00000000000..7f1f9f569af --- /dev/null +++ b/backends/arm/test/ops/test_sin.py @@ -0,0 +1,83 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.sin.default" +input_t1 = Tuple[torch.Tensor] # Input x + +test_data_suite = { + # (test_name, test_data) + "zeros": torch.zeros(10, 10, 10, 10), + "ones": torch.ones(10, 10, 10), + "rand": torch.rand(10, 10) - 0.5, + "randn_pos": torch.randn(10) + 10, + "randn_neg": torch.randn(10) - 10, + "ramp": torch.arange(-16, 16, 0.2), +} + + +class Sin(torch.nn.Module): + + def forward(self, x: torch.Tensor): + return torch.sin(x) + + +@common.parametrize("test_data", test_data_suite) +def test_sin_tosa_MI(test_data: Tuple): + pipeline = TosaPipelineMI[input_t1]( + Sin(), + (test_data,), + aten_op, + exir_op=[], + ) + if conftest.get_option("tosa_version") == "1.0": + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_sin_tosa_BI(test_data: Tuple): + pipeline = TosaPipelineBI[input_t1]( + Sin(), + (test_data,), + aten_op, + exir_op=[], + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_sin_tosa_u55_BI(test_data: Tuple): + pipeline = EthosU55PipelineBI[input_t1]( + Sin(), + (test_data,), + aten_op, + exir_ops=[], + run_on_fvp=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_sin_tosa_u85_BI(test_data: Tuple): + pipeline = EthosU85PipelineBI[input_t1]( + Sin(), + (test_data,), + aten_op, + exir_ops=[], + run_on_fvp=False, + ) + pipeline.run()