diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 3732c8a367b..02510600d82 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -41,6 +41,7 @@ class TableOps: # Targets that follow a straigtforward one-to-one mapping to their table op unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = { exir_ops.edge.aten.ceil.default: torch.ceil, + exir_ops.edge.aten.erf.default: torch.erf, exir_ops.edge.aten.exp.default: torch.exp, exir_ops.edge.aten.floor.default: torch.floor, exir_ops.edge.aten.log.default: torch.log, diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 0e5d7ecc958..95e351cfee3 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -166,6 +166,7 @@ def is_node_supported( exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.eq.Scalar, + exir_ops.edge.aten.erf.default, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.log.default, exir_ops.edge.aten.linear.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 2a610536f3e..b62e8940ed2 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -19,6 +19,7 @@ op_constant_pad_nd, op_conv2d, op_eq, + op_erf, op_exp, op_full, op_ge, diff --git a/backends/arm/operators/op_erf.py b/backends/arm/operators/op_erf.py new file mode 100644 index 00000000000..d0dc2af572f --- /dev/null +++ b/backends/arm/operators/op_erf.py @@ -0,0 +1,44 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts # type: ignore +import torch.fx +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class ERFVisitor_080_MI(NodeVisitor): + target = "aten.erf.default" + + # BI case handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + if not (inputs[0].dtype == output.dtype): + raise ValueError( + "All inputs and output need same dtype." + f"Got {inputs[0].dtype=}, {output.dtype=}" + ) + if not (inputs[0].dtype == ts.DType.FP32): + raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}") + # MI lowering + tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name]) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index baca13029a3..e32cb6e32a1 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -164,6 +164,7 @@ def _match_pattern( _one_to_one = [ torch.ops.aten.abs.default, torch.ops.aten.ceil.default, + torch.ops.aten.erf.default, torch.ops.aten.exp.default, torch.ops.aten.floor.default, torch.ops.aten.log.default, diff --git a/backends/arm/test/ops/test_erf.py b/backends/arm/test/ops/test_erf.py new file mode 100644 index 00000000000..d452be7cae1 --- /dev/null +++ b/backends/arm/test/ops/test_erf.py @@ -0,0 +1,63 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.erf.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_erf_default" +input_t1 = Tuple[torch.Tensor] # Input x + + +class Erf(torch.nn.Module): + def forward(self, x: torch.Tensor): + return torch.erf(x) + + test_data: dict[str, input_t1] = { + "zeros": (torch.zeros(1, 10, 10, 10),), + "ones": (torch.ones(10, 10, 10),), + "rand": ((torch.rand(10, 10) - 0.5),), + "randn_pos": ((torch.randn(1, 4, 4, 4) + 10),), + "randn_neg": ((torch.randn(1, 4, 4, 4) - 10),), + "ramp": (torch.arange(-16, 16, 0.2),), + } + + +@common.parametrize("test_data", Erf.test_data) +def test_erf_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1](Erf(), test_data, aten_op, exir_op) + pipeline.run() + + +@common.parametrize("test_data", Erf.test_data) +def test_erf_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1](Erf(), test_data, aten_op, exir_op) + pipeline.run() + + +@common.parametrize("test_data", Erf.test_data) +@common.XfailIfNoCorstone300 +def test_erf_u55_BI(test_data: input_t1): + pipeline = EthosU55PipelineBI[input_t1]( + Erf(), test_data, aten_op, exir_op, run_on_fvp=True + ) + pipeline.run() + + +@common.parametrize("test_data", Erf.test_data) +@common.XfailIfNoCorstone320 +def test_erf_u85_BI(test_data: input_t1): + pipeline = EthosU85PipelineBI[input_t1]( + Erf(), test_data, aten_op, exir_op, run_on_fvp=True + ) + pipeline.run()