Arm backend: Add is_node_supported checks for 4 ops #8209

Merged
merged 2 commits on Feb 6, 2025

11 changes: 9 additions & 2 deletions backends/arm/operator_support/__init__.py
@@ -1,8 +1,15 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

-from . import right_shift_support, to_copy_support, tosa_supported_operators  # noqa
+from . import (  # noqa
+    convolution_support,
+    pool_2d_support,
+    reduce_sum_support,
+    right_shift_support,
+    to_copy_support,
+    tosa_supported_operators,
+)
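
Note: listing the new modules in this __init__ is what activates the checks, since each module registers its class with the @register_tosa_support_check decorator at import time. A minimal sketch of the pattern each new file follows (the target op is hypothetical, chosen only for illustration):

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
    register_tosa_support_check,
    SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class ExampleSupported(SupportedTOSAOperatorCheck):
    # Hypothetical target op, used only to illustrate the registration pattern.
    targets = [exir_ops.edge.aten.tanh.default]

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
    ]

    def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
        return True  # accept unconditionally in this sketch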
99 changes: 99 additions & 0 deletions backends/arm/operator_support/convolution_support.py
@@ -0,0 +1,99 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import torch
import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class ConvolutionSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.convolution.default]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# Not implemented: transposed convolution and nonzero output padding
transposed = cast(bool, node.args[6])
output_padding = cast(list[int], node.args[7])
if transposed:
return False

for pad in output_padding:
if pad != 0:
return False

# Hardware specific constraints
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
return True
else:
return self._is_node_supported_u55(node)

def _is_node_supported_u55(self, node: fx.Node):
"""Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""

shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
shape_out = node.meta["val"].shape
kernel = cast(fx.Node, node.args[1]).meta["val"].shape
group = cast(int, node.args[8])

C_in = shape_in[1]
C_out = shape_out[1]
if (C_in == group) and (C_out % C_in) == 0:
# Depthwise convolution
for dim in shape_in[1:]:
if not 1 <= dim <= 65536:
return False
else:
# Convolution
if not 1 <= C_in <= 65536:
return False

kernel_w = kernel[2]
kernel_h = kernel[3] if len(kernel) > 3 else 1
# The kernel check omits Vela's constraint on the sum of absolute weights
if not 1 <= kernel_h <= 64 or not 1 <= kernel_w * kernel_h <= 4096:
return False

if not self._stride_condition(node):
return False

return True

def _stride_condition(self, node: fx.Node) -> bool:
"""This condition is somewhat complex but boils down
to not supporting stride > 3, unless we have some special conditions.
This condition is a simplified, relaxed version of the hardware constraint,
since the actual constraint requires information not available
here (without a lot of work).

This means that we might accept ops that are not actually supported.
"""
strides = cast(list[int], node.args[3])
has_padding = any(pad > 0 for pad in cast(list[int], node.args[4]))
dilations = cast(list[int], node.args[5])
if len(dilations) == 1:
dilations = [dilations[0]] * 2
if len(strides) == 1:
strides = [strides[0]] * 2

for stride, dilation in zip(strides, dilations):
stride_condition = 1 <= stride <= 3
dilation_condition = (not has_padding) and (dilation == 1)
if (not stride_condition) and (not dilation_condition):
return False

return True
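
Note: the stride rule above can be read in isolation: a stride outside [1, 3] is accepted only when the convolution has no padding and a dilation of 1. A standalone sketch of that rule on plain lists (the function and argument names are hypothetical, for illustration only):

def stride_ok(strides: list[int], dilations: list[int], paddings: list[int]) -> bool:
    has_padding = any(pad > 0 for pad in paddings)
    # Broadcast single-element strides/dilations to both spatial dims,
    # mirroring _stride_condition above.
    if len(dilations) == 1:
        dilations = [dilations[0]] * 2
    if len(strides) == 1:
        strides = [strides[0]] * 2
    for stride, dilation in zip(strides, dilations):
        stride_condition = 1 <= stride <= 3
        dilation_condition = (not has_padding) and (dilation == 1)
        if (not stride_condition) and (not dilation_condition):
            return False
    return True


assert stride_ok([1, 1], [1, 1], [0, 0])      # ordinary stride: accepted
assert stride_ok([4, 4], [1, 1], [0, 0])      # stride 4, but no padding and dilation 1: accepted
assert not stride_ok([4, 4], [1, 1], [1, 1])  # stride 4 with padding: rejected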
85 changes: 85 additions & 0 deletions backends/arm/operator_support/pool_2d_support.py
@@ -0,0 +1,85 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import torch
import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


def kernel_check(kernel: tuple[int, int]) -> bool:
if not (1 <= kernel[0] * kernel[1] <= 65536):
return False
return 1 <= kernel[1] <= 256


def stride_check(strides: tuple[int, int]) -> bool:
return all(1 <= stride <= 3 for stride in strides)


def dim_check(shape: torch.Size) -> bool:
check = shape[0] == 1
for dim in shape:
check &= 1 <= dim <= 65536
return check


@register_tosa_support_check
class AvgPool2dSupported(SupportedTOSAOperatorCheck):
targets = [
exir_ops.edge.aten.avg_pool2d.default,
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
return True

# U55 case, Vela 4.2.0 (25.02 release)
shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
kernel = cast(tuple[int, int], node.args[1])
stride = cast(tuple[int, int], node.args[2])
if len(node.args) > 3:
# Padding case
if not all(1 <= k <= 8 for k in kernel):
return False
else:
if not kernel_check(kernel):
return False

return dim_check(shape) and stride_check(stride)


@register_tosa_support_check
class MaxPool2dSupported(SupportedTOSAOperatorCheck):
targets = [
exir_ops.edge.aten.max_pool2d_with_indices.default,
]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
return True

# U55 case, Vela 4.2.0 (25.02 release)
shape = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
kernel = cast(tuple[int, int], node.args[1])
stride = cast(tuple[int, int], node.args[2])

return kernel_check(kernel) and dim_check(shape) and stride_check(stride)
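
Note: the three helpers compose the full U55 pooling constraint: the product of the kernel dims must lie in [1, 65536] with kernel width at most 256 (or at most 8 per dim when avg_pool2d has explicit padding), strides must lie in [1, 3], and the input must have batch 1 with every dim in [1, 65536]. A few illustrative invocations, assuming the module is importable from the package path added in this PR:

import torch
from executorch.backends.arm.operator_support.pool_2d_support import (
    dim_check,
    kernel_check,
    stride_check,
)

assert kernel_check((3, 3))          # product 9 in [1, 65536], width 3 in [1, 256]
assert not kernel_check((2, 300))    # kernel width 300 exceeds 256
assert not stride_check((1, 4))      # stride 4 falls outside [1, 3]
assert dim_check(torch.Size([1, 16, 5, 32]))      # batch 1, all dims in range
assert not dim_check(torch.Size([2, 16, 5, 32]))  # batch != 1 is rejected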
51 changes: 51 additions & 0 deletions backends/arm/operator_support/reduce_sum_support.py
@@ -0,0 +1,51 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class SumSupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.sum.dim_IntList]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
return True

# U55 case, Vela 4.2.0 (25.02 release)
input_shape = node.all_input_nodes[0].meta["val"].shape
dim_list = cast(list[int], node.args[1])
dim_list = [dim % len(input_shape) for dim in dim_list]

for dim in dim_list:
if not 1 <= input_shape[dim] <= 65536:
return False

# We can't be certain which dim is last in memory yet,
# so always apply the stricter condition.
pre_R_product = 1.0
for length in input_shape[:dim]:
pre_R_product *= length
post_R_product = 1.0
for length in input_shape[dim + 1 :]:
post_R_product *= length
if not 1 <= pre_R_product <= 65536:
return False
if not 1 <= post_R_product <= 65536:
return False
return True
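
Note: for each reduced dim, the check splits the input shape around that dim and bounds both partial products. A worked example (sketch) for an input of shape (1, 8, 128, 128) reduced over dim 2:

shape = (1, 8, 128, 128)
dim = 2

pre_product = 1
for n in shape[:dim]:
    pre_product *= n   # 1 * 8 = 8

post_product = 1
for n in shape[dim + 1:]:
    post_product *= n  # 128

# The reduced dim (128) and both partial products (8 and 128) all fall
# within [1, 65536], so this sum would be accepted for U55.
assert 1 <= shape[dim] <= 65536
assert 1 <= pre_product <= 65536
assert 1 <= post_product <= 65536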
4 changes: 0 additions & 4 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -82,7 +82,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.hardsigmoid.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.hardswish.default,
-exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten.eq.Tensor,
exir_ops.edge.aten.exp.default,
@@ -97,8 +96,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
-exir_ops.edge.aten.avg_pool2d.default,
-exir_ops.edge.aten.max_pool2d_with_indices.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.mm.default,
@@ -113,7 +110,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.sub.Tensor,
-exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.var.correction,
1 change: 0 additions & 1 deletion backends/arm/operators/op_log.py
@@ -36,7 +36,6 @@ def define_node(
output: TosaArg,
) -> None:
assert len(node.all_input_nodes) == 1
-assert len(node.users) == 1
assert inputs[0].dtype == output.dtype == ts.DType.FP32

tosa_graph.addOperator(TosaOp.Op().LOG, [inputs[0].name], [output.name])
1 change: 0 additions & 1 deletion backends/arm/operators/op_sigmoid.py
@@ -37,7 +37,6 @@ def define_node(
) -> None:

assert len(node.all_input_nodes) == 1
-assert len(node.users) == 1
assert inputs[0].dtype == output.dtype == ts.DType.FP32

tosa_graph.addOperator(TosaOp.Op().SIGMOID, [inputs[0].name], [output.name])
33 changes: 33 additions & 0 deletions backends/arm/test/ops/test_avg_pool.py
@@ -172,3 +172,36 @@ def test_avgpool2d_tosa_u85_BI(
common.get_u85_compile_spec(),
(test_data,),
)

reject_data_suite = [
(AvgPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)),
(AvgPool2d((2, 9), 1, 1), torch.rand(1, 16, 5, 32)),
(AvgPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)),
(AvgPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)),
(AvgPool2d((800, 90), 1, 0), torch.rand(1, 16, 850, 100)),
]

@parameterized.expand(reject_data_suite)
def test_reject_avgpool2d_u55_BI(
self,
module: torch.nn.Module,
test_data: torch.Tensor,
):
compile_spec = common.get_u55_compile_spec()
tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
quantizer = ArmQuantizer(tosa_spec).set_io(get_symmetric_quantization_config())

(
ArmTester(
module,
example_inputs=(test_data,),
compile_spec=compile_spec,
)
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
.check_count({"torch.ops.aten.avg_pool2d.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
)
55 changes: 55 additions & 0 deletions backends/arm/test/ops/test_conv2d.py
@@ -9,6 +9,7 @@

import torch
from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
@@ -406,3 +407,57 @@ def test_conv2d_u85_BI_on_fvp(test_module):
test_module, test_module.get_inputs(), aten_op, exir_op, run_on_fvp=True
)
pipeline.run()


reject_suite = {
"large_stride": Conv2d(
in_channels=1,
out_channels=1,
kernel_size=(2, 4),
stride=(2, 4),
padding=1,
width=10,
height=14,
batches=1,
),
"large_kernel_height": Conv2d(
in_channels=1,
out_channels=1,
kernel_size=(2, 65),
stride=(1, 1),
padding=0,
width=70,
height=70,
batches=1,
),
"large_kernel": Conv2d(
in_channels=1,
out_channels=1,
kernel_size=(70, 60),
stride=(1,),
padding=0,
width=80,
height=80,
batches=1,
),
}


@common.parametrize("module", reject_suite)
def test_reject_conv2d_u55_BI(
module: Conv2d,
):
(
ArmTester(
module,
example_inputs=module.get_inputs(),
compile_spec=common.get_u55_compile_spec(),
)
.quantize()
.export()
.check_count({"torch.ops.aten.conv2d.default": 1})
.check(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
)