Skip to content

Arm backend: Update more node visitors to support TOSA 1.0 #10425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
ce7b9cb
Arm backend: Support for reduce_any for TOSA 1.0
oscarandersson8218 Apr 10, 2025
0c6892a
Arm backend: Support for concat for TOSA 1.0
oscarandersson8218 Apr 10, 2025
57b193e
Arm backend: Support for ERF for TOSA 1.0
oscarandersson8218 Apr 10, 2025
03b4997
Arm backend: Support for exp for TOSA 1.0
oscarandersson8218 Apr 10, 2025
b55ad4b
Arm backend: Support for log for TOSA 1.0
oscarandersson8218 Apr 10, 2025
273fa6a
Arm backend: Support for permute for TOSA 1.0
oscarandersson8218 Apr 10, 2025
a0a758d
Arm backend: Support for pow for TOSA 1.0
oscarandersson8218 Apr 10, 2025
0384ed2
Arm backend: Support for reciprocal for TOSA 1.0
oscarandersson8218 Apr 10, 2025
9256398
Arm backend: Support for right_shift for TOSA 1.0
oscarandersson8218 Apr 10, 2025
fad0e19
Arm backend: Support for rsqrt for TOSA 1.0
oscarandersson8218 Apr 10, 2025
ca9f091
Arm backend: Support for sigmoid for TOSA 1.0
oscarandersson8218 Apr 10, 2025
575afda
Arm backend: Support for tanh for TOSA 1.0
oscarandersson8218 Apr 10, 2025
f10eb69
Arm backend: Support for upsample_nearest TOSA 1.0
oscarandersson8218 Apr 10, 2025
3358e9b
Arm backend: Support for where.self for TOSA 1.0
oscarandersson8218 Apr 10, 2025
861b4f7
Arm backend: Support for binary ops for TOSA 1.0
oscarandersson8218 Apr 11, 2025
3a54bdd
Arm backend: Support for unary ops for TOSA 1.0
oscarandersson8218 Apr 11, 2025
804866f
Arm backend: Support for avg_pool2d for TOSA 1.0
oscarandersson8218 Apr 8, 2025
9f3a16d
Arm backend: Support for max_pool2d for TOSA 1.0
oscarandersson8218 Apr 8, 2025
0a06a5e
Arm backend: Support for constant_pad for TOSA 1.0
oscarandersson8218 Apr 8, 2025
c1633ed
Arm backend: Support for sin and cos for TOSA 1.0
oscarandersson8218 Apr 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class TableOps:
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
exir_ops.edge.aten.cos.default: torch.cos,
exir_ops.edge.aten.sin.default: torch.sin,
exir_ops.edge.aten.tanh.default: torch.tanh,
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
pool_2d_support,
reduce_sum_support,
right_shift_support,
sin_cos_support,
slice_copy_support,
to_copy_support,
tosa_supported_operators,
Expand Down
32 changes: 32 additions & 0 deletions backends/arm/operator_support/sin_cos_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class SinCosSupported(SupportedTOSAOperatorCheck):
targets = [
exir_ops.edge.aten.cos.default,
exir_ops.edge.aten.sin.default,
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
return True
10 changes: 8 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
EthosU55NotSupported,
EthosU55TransposeCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.arm.tosa_specification import (
Tosa_0_80,
Tosa_1_00,
TosaSpecification,
)
from executorch.exir import ExportedProgram
from executorch.exir.backend.utils import WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops
Expand Down Expand Up @@ -124,7 +128,9 @@ def tosa_support_factory(
if not tosa_spec.support_float():
negative_checks.append(NeedsDecompositionCheck(reporter))
negative_checks.append(CheckProperQuantization(reporter))
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or (
isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions
):
negative_checks.append(EthosU55NotSupported(reporter))
negative_checks.append(EthosU55DtypeSupport(reporter))
negative_checks.append(EthosU55TransposeCheck(reporter))
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
op_clamp,
op_constant_pad_nd,
op_conv2d,
op_cos,
op_eq,
op_erf,
op_exp,
Expand All @@ -38,6 +39,7 @@
op_rshift_tensor,
op_rsqrt,
op_sigmoid,
op_sin,
op_slice,
op_sub,
op_sum,
Expand Down
49 changes: 45 additions & 4 deletions backends/arm/operators/op_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import cast, List
from typing import Any, cast, List

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
NodeVisitor,
register_node_visitor,
Expand All @@ -17,16 +16,19 @@


@register_node_visitor
class AnyVisitor(NodeVisitor):
class AnyVisitor_0_80(NodeVisitor):
target = "aten.any.dim"

tosa_specs = NodeVisitor.tosa_specs_0_80

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

if not (inputs[0].dtype == output.dtype):
raise ValueError(
Expand All @@ -50,3 +52,42 @@ def define_node(
tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
)


@register_node_visitor
class AnyVisitor(NodeVisitor):
target = "aten.any.dim"

tosa_specs = NodeVisitor.tosa_specs_1_00

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts

if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
)
if not (inputs[0].dtype == ts.DType.BOOL):
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")

input_shape = list(inputs[0].shape)
dim = cast(int, inputs[1].number) % len(
input_shape
) # process the negative index
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
if not keep_dim:
raise ValueError("This case should be handled by ConvertAnyDimDimsPass")

attr = ts.TosaSerializerAttribute()
attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))

tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
)
141 changes: 134 additions & 7 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List
from typing import Any, List

import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
Expand All @@ -36,14 +34,16 @@ def __init__(self, *args):
def _build_generic_avgpool2d(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
input_zp: int,
output_zp: int,
accumulator_type: ts.DType,
accumulator_type: Any,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

input_tensor = inputs[0]
kernel_size_list = inputs[1].special
stride_size_list = inputs[2].special
Expand Down Expand Up @@ -79,10 +79,12 @@ def _build_generic_avgpool2d(
def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

input_tensor = inputs[0]
assert input_tensor.dtype == ts.DType.INT8

Expand Down Expand Up @@ -110,10 +112,135 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

assert (
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
), "Only FP32 and INT8 supported"

if inputs[0].dtype == ts.DType.INT8:
super().define_node(node, tosa_graph, inputs, output)

if inputs[0].dtype == ts.DType.FP32:
accumulator_type = ts.DType.FP32
# Initilize zero point to zero.
input_zp = 0
output_zp = 0

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)


@register_node_visitor
class AvgPool2dVisitor(NodeVisitor):
target = "aten.avg_pool2d.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
]

def __init__(self, *args):
super().__init__(*args)

def _build_generic_avgpool2d(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
input_zp: int,
output_zp: int,
accumulator_type: Any,
) -> None:

import serializer.tosa_serializer as ts # type: ignore

input_tensor = inputs[0]
kernel_size_list = inputs[1].special
stride_size_list = inputs[2].special

try:
pad_size_list = inputs[3].special
pad_size_list = [
pad_size_list[0],
pad_size_list[0],
pad_size_list[1],
pad_size_list[1],
]
except IndexError:
pad_size_list = [0, 0, 0, 0]

attr = ts.TosaSerializerAttribute()
attr.AvgPool2dAttribute(
kernel=kernel_size_list,
stride=stride_size_list,
pad=pad_size_list,
acc_type=accumulator_type,
)
input_zp_tensor = tosa_graph.addConst(
shape=[1], dtype=output.dtype, vals=[input_zp]
)
output_zp_tensor = tosa_graph.addConst(
shape=[1], dtype=output.dtype, vals=[output_zp]
)

tosa_graph.addOperator(
ts.TosaOp.Op().AVG_POOL2D,
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
[output.name],
attr,
)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts # type: ignore

input_tensor = inputs[0]
assert input_tensor.dtype == ts.DType.INT8

accumulator_type = ts.DType.INT32

input_qargs = get_input_qparams(node)
input_zp = input_qargs[0].zp

output_qargs = get_output_qparams(node)
output_zp = output_qargs[0].zp

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)


@register_node_visitor
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
target = "aten.avg_pool2d.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts # type: ignore

assert (
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
), "Only FP32 and INT8 supported"
Expand Down
Loading
Loading