Skip to content

Revert "Arm backend: Update more node visitors to support TOSA 1.0" #10455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ class TableOps:
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
exir_ops.edge.aten.cos.default: torch.cos,
exir_ops.edge.aten.sin.default: torch.sin,
exir_ops.edge.aten.tanh.default: torch.tanh,
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,
Expand Down
1 change: 0 additions & 1 deletion backends/arm/operator_support/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
pool_2d_support,
reduce_sum_support,
right_shift_support,
sin_cos_support,
slice_copy_support,
to_copy_support,
tosa_supported_operators,
Expand Down
32 changes: 0 additions & 32 deletions backends/arm/operator_support/sin_cos_support.py

This file was deleted.

10 changes: 2 additions & 8 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@
EthosU55NotSupported,
EthosU55TransposeCheck,
)
from executorch.backends.arm.tosa_specification import (
Tosa_0_80,
Tosa_1_00,
TosaSpecification,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.exir import ExportedProgram
from executorch.exir.backend.utils import WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops
Expand Down Expand Up @@ -128,9 +124,7 @@ def tosa_support_factory(
if not tosa_spec.support_float():
negative_checks.append(NeedsDecompositionCheck(reporter))
negative_checks.append(CheckProperQuantization(reporter))
if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or (
isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions
):
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
negative_checks.append(EthosU55NotSupported(reporter))
negative_checks.append(EthosU55DtypeSupport(reporter))
negative_checks.append(EthosU55TransposeCheck(reporter))
Expand Down
2 changes: 0 additions & 2 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
op_clamp,
op_constant_pad_nd,
op_conv2d,
op_cos,
op_eq,
op_erf,
op_exp,
Expand All @@ -39,7 +38,6 @@
op_rshift_tensor,
op_rsqrt,
op_sigmoid,
op_sin,
op_slice,
op_sub,
op_sum,
Expand Down
49 changes: 4 additions & 45 deletions backends/arm/operators/op_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import Any, cast, List
from typing import cast, List

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
NodeVisitor,
register_node_visitor,
Expand All @@ -15,59 +16,17 @@
from torch.fx import Node


@register_node_visitor
class AnyVisitor_0_80(NodeVisitor):
target = "aten.any.dim"

tosa_specs = NodeVisitor.tosa_specs_0_80

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need same dtype."
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
)
if not (inputs[0].dtype == ts.DType.BOOL):
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")

input_shape = list(inputs[0].shape)
dim = cast(int, inputs[1].number) % len(
input_shape
) # process the negative index
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
if not keep_dim:
raise ValueError("This case should be handled by ConvertAnyDimDimsPass")

attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(inputs[0].dim_order.index(dim))

tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
)


@register_node_visitor
class AnyVisitor(NodeVisitor):
target = "aten.any.dim"

tosa_specs = NodeVisitor.tosa_specs_1_00

def define_node(
self,
node: Node,
tosa_graph: Any,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts

if not (inputs[0].dtype == output.dtype):
raise ValueError(
Expand All @@ -86,7 +45,7 @@ def define_node(
raise ValueError("This case should be handled by ConvertAnyDimDimsPass")

attr = ts.TosaSerializerAttribute()
attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))
attr.AxisAttribute(inputs[0].dim_order.index(dim))

tosa_graph.addOperator(
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
Expand Down
141 changes: 7 additions & 134 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import Any, List
from typing import List

import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
Expand All @@ -34,16 +36,14 @@ def __init__(self, *args):
def _build_generic_avgpool2d(
self,
node: torch.fx.Node,
tosa_graph: Any,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
input_zp: int,
output_zp: int,
accumulator_type: Any,
accumulator_type: ts.DType,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

input_tensor = inputs[0]
kernel_size_list = inputs[1].special
stride_size_list = inputs[2].special
Expand Down Expand Up @@ -79,12 +79,10 @@ def _build_generic_avgpool2d(
def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

input_tensor = inputs[0]
assert input_tensor.dtype == ts.DType.INT8

Expand Down Expand Up @@ -112,135 +110,10 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

assert (
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
), "Only FP32 and INT8 supported"

if inputs[0].dtype == ts.DType.INT8:
super().define_node(node, tosa_graph, inputs, output)

if inputs[0].dtype == ts.DType.FP32:
accumulator_type = ts.DType.FP32
# Initilize zero point to zero.
input_zp = 0
output_zp = 0

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)


@register_node_visitor
class AvgPool2dVisitor(NodeVisitor):
target = "aten.avg_pool2d.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
]

def __init__(self, *args):
super().__init__(*args)

def _build_generic_avgpool2d(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
input_zp: int,
output_zp: int,
accumulator_type: Any,
) -> None:

import serializer.tosa_serializer as ts # type: ignore

input_tensor = inputs[0]
kernel_size_list = inputs[1].special
stride_size_list = inputs[2].special

try:
pad_size_list = inputs[3].special
pad_size_list = [
pad_size_list[0],
pad_size_list[0],
pad_size_list[1],
pad_size_list[1],
]
except IndexError:
pad_size_list = [0, 0, 0, 0]

attr = ts.TosaSerializerAttribute()
attr.AvgPool2dAttribute(
kernel=kernel_size_list,
stride=stride_size_list,
pad=pad_size_list,
acc_type=accumulator_type,
)
input_zp_tensor = tosa_graph.addConst(
shape=[1], dtype=output.dtype, vals=[input_zp]
)
output_zp_tensor = tosa_graph.addConst(
shape=[1], dtype=output.dtype, vals=[output_zp]
)

tosa_graph.addOperator(
ts.TosaOp.Op().AVG_POOL2D,
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
[output.name],
attr,
)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts # type: ignore

input_tensor = inputs[0]
assert input_tensor.dtype == ts.DType.INT8

accumulator_type = ts.DType.INT32

input_qargs = get_input_qparams(node)
input_zp = input_qargs[0].zp

output_qargs = get_output_qparams(node)
output_zp = output_qargs[0].zp

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)


@register_node_visitor
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
target = "aten.avg_pool2d.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
import serializer.tosa_serializer as ts # type: ignore

assert (
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
), "Only FP32 and INT8 supported"
Expand Down
Loading
Loading