Skip to content

Arm backend: Add ERF operator #9836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class TableOps:
# Targets that follow a straigtforward one-to-one mapping to their table op
unary_table_ops: Dict[EdgeOpOverload, Callable[[torch.Tensor], torch.Tensor]] = {
exir_ops.edge.aten.ceil.default: torch.ceil,
exir_ops.edge.aten.erf.default: torch.erf,
exir_ops.edge.aten.exp.default: torch.exp,
exir_ops.edge.aten.floor.default: torch.floor,
exir_ops.edge.aten.log.default: torch.log,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def is_node_supported(
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.eq.Scalar,
exir_ops.edge.aten.erf.default,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.linear.default,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
op_constant_pad_nd,
op_conv2d,
op_eq,
op_erf,
op_exp,
op_full,
op_ge,
Expand Down
44 changes: 44 additions & 0 deletions backends/arm/operators/op_erf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts # type: ignore
import torch.fx
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class ERFVisitor_080_MI(NodeVisitor):
target = "aten.erf.default"

# BI case handled by op_table
tosa_specs = [TosaSpecification.create_from_string("TOSA-0.80+MI")]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and output need same dtype."
f"Got {inputs[0].dtype=}, {output.dtype=}"
)
if not (inputs[0].dtype == ts.DType.FP32):
raise ValueError("All inputs need to be FP32." f"Got {inputs[0].dtype=}")
# MI lowering
tosa_graph.addOperator(TosaOp.Op().ERF, [inputs[0].name], [output.name])
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def _match_pattern(
_one_to_one = [
torch.ops.aten.abs.default,
torch.ops.aten.ceil.default,
torch.ops.aten.erf.default,
torch.ops.aten.exp.default,
torch.ops.aten.floor.default,
torch.ops.aten.log.default,
Expand Down
63 changes: 63 additions & 0 deletions backends/arm/test/ops/test_erf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

aten_op = "torch.ops.aten.erf.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_erf_default"
input_t1 = Tuple[torch.Tensor] # Input x


class Erf(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.erf(x)

test_data: dict[str, input_t1] = {
"zeros": (torch.zeros(1, 10, 10, 10),),
"ones": (torch.ones(10, 10, 10),),
"rand": ((torch.rand(10, 10) - 0.5),),
"randn_pos": ((torch.randn(1, 4, 4, 4) + 10),),
"randn_neg": ((torch.randn(1, 4, 4, 4) - 10),),
"ramp": (torch.arange(-16, 16, 0.2),),
}


@common.parametrize("test_data", Erf.test_data)
def test_erf_tosa_MI(test_data: input_t1):
pipeline = TosaPipelineMI[input_t1](Erf(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", Erf.test_data)
def test_erf_tosa_BI(test_data: input_t1):
pipeline = TosaPipelineBI[input_t1](Erf(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", Erf.test_data)
@common.XfailIfNoCorstone300
def test_erf_u55_BI(test_data: input_t1):
pipeline = EthosU55PipelineBI[input_t1](
Erf(), test_data, aten_op, exir_op, run_on_fvp=True
)
pipeline.run()


@common.parametrize("test_data", Erf.test_data)
@common.XfailIfNoCorstone320
def test_erf_u85_BI(test_data: input_t1):
pipeline = EthosU85PipelineBI[input_t1](
Erf(), test_data, aten_op, exir_op, run_on_fvp=True
)
pipeline.run()
Loading