Skip to content

Commit 032ba6c

Browse files
Arm backend: Update more node visitors to support TOSA 1.0 (#10425)
### Summary Updates more node visitors to support TOSA 1.0 specification. ### Test plan Tested through public and internal CI. Signed-off-by: Oscar Andersson <[email protected]>
1 parent d31ef13 commit 032ba6c

32 files changed

+1519
-236
lines changed

backends/arm/_passes/insert_table_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class TableOps:
4848
exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
4949
exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
5050
exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
51+
exir_ops.edge.aten.cos.default: torch.cos,
52+
exir_ops.edge.aten.sin.default: torch.sin,
5153
exir_ops.edge.aten.tanh.default: torch.tanh,
5254
exir_ops.edge.aten.hardsigmoid.default: torch.nn.functional.hardsigmoid,
5355
exir_ops.edge.aten.hardswish.default: torch.nn.functional.hardswish,

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
pool_2d_support,
1313
reduce_sum_support,
1414
right_shift_support,
15+
sin_cos_support,
1516
slice_copy_support,
1617
to_copy_support,
1718
tosa_supported_operators,
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import torch.fx as fx
10+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
11+
register_tosa_support_check,
12+
SupportedTOSAOperatorCheck,
13+
)
14+
from executorch.backends.arm.tosa_specification import TosaSpecification
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
18+
@register_tosa_support_check
19+
class SinCosSupported(SupportedTOSAOperatorCheck):
20+
targets = [
21+
exir_ops.edge.aten.cos.default,
22+
exir_ops.edge.aten.sin.default,
23+
]
24+
25+
tosa_specs = [
26+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
27+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
28+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
29+
]
30+
31+
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
32+
return True

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@
2323
EthosU55NotSupported,
2424
EthosU55TransposeCheck,
2525
)
26-
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
26+
from executorch.backends.arm.tosa_specification import (
27+
Tosa_0_80,
28+
Tosa_1_00,
29+
TosaSpecification,
30+
)
2731
from executorch.exir import ExportedProgram
2832
from executorch.exir.backend.utils import WhyNoPartitionReporter
2933
from executorch.exir.dialects._ops import ops as exir_ops
@@ -124,7 +128,9 @@ def tosa_support_factory(
124128
if not tosa_spec.support_float():
125129
negative_checks.append(NeedsDecompositionCheck(reporter))
126130
negative_checks.append(CheckProperQuantization(reporter))
127-
if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset:
131+
if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or (
132+
isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions
133+
):
128134
negative_checks.append(EthosU55NotSupported(reporter))
129135
negative_checks.append(EthosU55DtypeSupport(reporter))
130136
negative_checks.append(EthosU55TransposeCheck(reporter))

backends/arm/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
op_clamp,
1919
op_constant_pad_nd,
2020
op_conv2d,
21+
op_cos,
2122
op_eq,
2223
op_erf,
2324
op_exp,
@@ -38,6 +39,7 @@
3839
op_rshift_tensor,
3940
op_rsqrt,
4041
op_sigmoid,
42+
op_sin,
4143
op_slice,
4244
op_sub,
4345
op_sum,

backends/arm/operators/op_any.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import cast, List
7+
from typing import Any, cast, List
88

9-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
109
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
1110
NodeVisitor,
1211
register_node_visitor,
@@ -17,16 +16,19 @@
1716

1817

1918
@register_node_visitor
20-
class AnyVisitor(NodeVisitor):
19+
class AnyVisitor_0_80(NodeVisitor):
2120
target = "aten.any.dim"
2221

22+
tosa_specs = NodeVisitor.tosa_specs_0_80
23+
2324
def define_node(
2425
self,
2526
node: Node,
26-
tosa_graph: ts.TosaSerializer,
27+
tosa_graph: Any,
2728
inputs: List[TosaArg],
2829
output: TosaArg,
2930
) -> None:
31+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3032

3133
if not (inputs[0].dtype == output.dtype):
3234
raise ValueError(
@@ -50,3 +52,42 @@ def define_node(
5052
tosa_graph.addOperator(
5153
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
5254
)
55+
56+
57+
@register_node_visitor
58+
class AnyVisitor(NodeVisitor):
59+
target = "aten.any.dim"
60+
61+
tosa_specs = NodeVisitor.tosa_specs_1_00
62+
63+
def define_node(
64+
self,
65+
node: Node,
66+
tosa_graph: Any,
67+
inputs: List[TosaArg],
68+
output: TosaArg,
69+
) -> None:
70+
import serializer.tosa_serializer as ts
71+
72+
if not (inputs[0].dtype == output.dtype):
73+
raise ValueError(
74+
"All inputs and outputs need same dtype."
75+
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
76+
)
77+
if not (inputs[0].dtype == ts.DType.BOOL):
78+
raise ValueError("All inputs need to be BOOL." f"Got {inputs[0].dtype=}")
79+
80+
input_shape = list(inputs[0].shape)
81+
dim = cast(int, inputs[1].number) % len(
82+
input_shape
83+
) # process the negative index
84+
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
85+
if not keep_dim:
86+
raise ValueError("This case should be handled by ConvertAnyDimDimsPass")
87+
88+
attr = ts.TosaSerializerAttribute()
89+
attr.ReduceAnyAttribute(inputs[0].dim_order.index(dim))
90+
91+
tosa_graph.addOperator(
92+
ts.TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
93+
)

backends/arm/operators/op_avg_pool2d.py

Lines changed: 134 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12-
1311
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1412
get_input_qparams,
1513
get_output_qparams,
@@ -36,14 +34,16 @@ def __init__(self, *args):
3634
def _build_generic_avgpool2d(
3735
self,
3836
node: torch.fx.Node,
39-
tosa_graph: ts.TosaSerializer,
37+
tosa_graph: Any,
4038
inputs: List[TosaArg],
4139
output: TosaArg,
4240
input_zp: int,
4341
output_zp: int,
44-
accumulator_type: ts.DType,
42+
accumulator_type: Any,
4543
) -> None:
4644

45+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
46+
4747
input_tensor = inputs[0]
4848
kernel_size_list = inputs[1].special
4949
stride_size_list = inputs[2].special
@@ -79,10 +79,12 @@ def _build_generic_avgpool2d(
7979
def define_node(
8080
self,
8181
node: torch.fx.Node,
82-
tosa_graph: ts.TosaSerializer,
82+
tosa_graph: Any,
8383
inputs: List[TosaArg],
8484
output: TosaArg,
8585
) -> None:
86+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
87+
8688
input_tensor = inputs[0]
8789
assert input_tensor.dtype == ts.DType.INT8
8890

@@ -110,10 +112,135 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
110112
def define_node(
111113
self,
112114
node: torch.fx.Node,
113-
tosa_graph: ts.TosaSerializer,
115+
tosa_graph: Any,
114116
inputs: List[TosaArg],
115117
output: TosaArg,
116118
) -> None:
119+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120+
121+
assert (
122+
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
123+
), "Only FP32 and INT8 supported"
124+
125+
if inputs[0].dtype == ts.DType.INT8:
126+
super().define_node(node, tosa_graph, inputs, output)
127+
128+
if inputs[0].dtype == ts.DType.FP32:
129+
accumulator_type = ts.DType.FP32
130+
# Initilize zero point to zero.
131+
input_zp = 0
132+
output_zp = 0
133+
134+
self._build_generic_avgpool2d(
135+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
136+
)
137+
138+
139+
@register_node_visitor
140+
class AvgPool2dVisitor(NodeVisitor):
141+
target = "aten.avg_pool2d.default"
142+
143+
tosa_specs = [
144+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
145+
]
146+
147+
def __init__(self, *args):
148+
super().__init__(*args)
149+
150+
def _build_generic_avgpool2d(
151+
self,
152+
node: torch.fx.Node,
153+
tosa_graph: Any,
154+
inputs: List[TosaArg],
155+
output: TosaArg,
156+
input_zp: int,
157+
output_zp: int,
158+
accumulator_type: Any,
159+
) -> None:
160+
161+
import serializer.tosa_serializer as ts # type: ignore
162+
163+
input_tensor = inputs[0]
164+
kernel_size_list = inputs[1].special
165+
stride_size_list = inputs[2].special
166+
167+
try:
168+
pad_size_list = inputs[3].special
169+
pad_size_list = [
170+
pad_size_list[0],
171+
pad_size_list[0],
172+
pad_size_list[1],
173+
pad_size_list[1],
174+
]
175+
except IndexError:
176+
pad_size_list = [0, 0, 0, 0]
177+
178+
attr = ts.TosaSerializerAttribute()
179+
attr.AvgPool2dAttribute(
180+
kernel=kernel_size_list,
181+
stride=stride_size_list,
182+
pad=pad_size_list,
183+
acc_type=accumulator_type,
184+
)
185+
input_zp_tensor = tosa_graph.addConst(
186+
shape=[1], dtype=output.dtype, vals=[input_zp]
187+
)
188+
output_zp_tensor = tosa_graph.addConst(
189+
shape=[1], dtype=output.dtype, vals=[output_zp]
190+
)
191+
192+
tosa_graph.addOperator(
193+
ts.TosaOp.Op().AVG_POOL2D,
194+
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
195+
[output.name],
196+
attr,
197+
)
198+
199+
def define_node(
200+
self,
201+
node: torch.fx.Node,
202+
tosa_graph: Any,
203+
inputs: List[TosaArg],
204+
output: TosaArg,
205+
) -> None:
206+
import serializer.tosa_serializer as ts # type: ignore
207+
208+
input_tensor = inputs[0]
209+
assert input_tensor.dtype == ts.DType.INT8
210+
211+
accumulator_type = ts.DType.INT32
212+
213+
input_qargs = get_input_qparams(node)
214+
input_zp = input_qargs[0].zp
215+
216+
output_qargs = get_output_qparams(node)
217+
output_zp = output_qargs[0].zp
218+
219+
self._build_generic_avgpool2d(
220+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
221+
)
222+
223+
224+
@register_node_visitor
225+
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
226+
target = "aten.avg_pool2d.default"
227+
228+
tosa_specs = [
229+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
230+
]
231+
232+
def __init__(self, *args):
233+
super().__init__(*args)
234+
235+
def define_node(
236+
self,
237+
node: torch.fx.Node,
238+
tosa_graph: Any,
239+
inputs: List[TosaArg],
240+
output: TosaArg,
241+
) -> None:
242+
import serializer.tosa_serializer as ts # type: ignore
243+
117244
assert (
118245
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
119246
), "Only FP32 and INT8 supported"

0 commit comments

Comments
 (0)