Skip to content

Commit 6556991

Browse files
Erik-Lundell authored and facebook-github-bot committed
Add full op for Arm backend (pytorch#4073)
Summary: Implements the full op, which creates a tensor of a given shape filled with a given value. The shape and value are set at compile time, i.e. they can't be set by a tensor input. Refactors tosa_quant_utils.is_quant_node to handle nodes with no inputs (or outputs). Does not add a full quantizer annotator; the op needs to be quantized by a SharedQuantizationSpec. Change-Id: I1cebd1da1af5b9aa726a363431ffc30d8259a0ff Pull Request resolved: pytorch#4073 Reviewed By: mergennachin Differential Revision: D59259731 Pulled By: digantdesai fbshipit-source-id: 621fec994bc2ebc4ad7abd51d9dbf1a5a4deed43
1 parent 908b5a5 commit 6556991

File tree

6 files changed

+247
-19
lines changed

6 files changed

+247
-19
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
4242
exir_ops.edge.aten.hardtanh.default,
4343
exir_ops.edge.aten.convolution.default,
4444
exir_ops.edge.aten.div.Tensor,
45+
exir_ops.edge.aten.full.default,
4546
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
4647
exir_ops.edge.aten.avg_pool2d.default,
4748
exir_ops.edge.aten._softmax.default,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
op_conv2d,
1313
op_dequant,
1414
op_div,
15+
op_full,
1516
op_get_item,
1617
op_hardtanh,
1718
op_mean_dim,

backends/arm/operators/op_full.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import List
6+
7+
import numpy as np
8+
9+
import serializer.tosa_serializer as ts
10+
from executorch.backends.arm.operators.node_visitor import (
11+
NodeVisitor,
12+
register_node_visitor,
13+
)
14+
from executorch.backends.arm.tosa_mapping import TosaArg
15+
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
16+
from executorch.backends.arm.tosa_utils import tosa_shape
17+
from torch.fx import Node
18+
19+
20+
@register_node_visitor
class FullVisitor(NodeVisitor):
    """Node visitor that lowers ``aten.full.default`` into a TOSA graph.

    The fill value and shape are read from the node's arguments when the
    graph is built, materialised as a constant tensor, and forwarded to the
    node's output through an IDENTITY operator.
    """

    target = "aten.full.default"

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit the constant tensor and IDENTITY op for one ``full`` node.

        Args:
            node: The fx node being lowered. Its first user is queried for
                quantization parameters when ``is_quant_node`` is true, and
                its name is used to build the constant's tensor name.
            tosa_graph: Serializer the constant and operator are added to.
            inputs: ``inputs[0].special`` holds the requested shape and
                ``inputs[1].number`` the fill value.
            output: Output tensor description (name, dtype, dim order).
            is_quant_node: Whether the node is part of a quantized graph.

        Raises:
            AssertionError: If the node is not quantized and the output
                dtype is not FP32.
        """
        shape = tosa_shape(inputs[0].special, output.dim_order)

        value = inputs[1].number
        if is_quant_node:
            # The quantization parameters are taken from the first consumer
            # of this node, since 'full' has no tensor inputs to read them from.
            qargs = get_quant_node_args(list(node.users)[0])
            qvalue = np.clip(
                np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax
            )
            dtype = ts.DType.INT8
            data = np.full(shape, qvalue, dtype=np.int8)
        else:
            assert (
                output.dtype == ts.DType.FP32
            ), "'Full' currently only supports FP32 for unquantized models."
            dtype = ts.DType.FP32
            data = np.full(shape, value, dtype=np.float32)

        # Derive the constant's name from the node name so that several
        # 'full' nodes lowered into the same TOSA graph do not all register
        # their data under one shared tensor name.
        const_name = f"{node.name}-full-const"
        tosa_graph.addConst(shape, dtype, data, const_name)
        tosa_graph.addOperator(ts.TosaOp.Op.IDENTITY, [const_name], [output.name])

backends/arm/test/ops/test_full.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
#
8+
# Tests the full op which creates a tensor of a given shape filled with a given value.
9+
# The shape and value are set at compile time, i.e. can't be set by a tensor input.
10+
#
11+
12+
import unittest
13+
from typing import Tuple
14+
15+
import torch
16+
from executorch.backends.arm.test import common
17+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
18+
from parameterized import parameterized
19+
20+
21+
class TestFull(unittest.TestCase):
    """Tests lowering of the full op through the Arm backend pipelines.

    The shape and fill value of the full op are fixed at compile time, so
    they cannot be changed through a tensor input at runtime.
    """

    class Full(torch.nn.Module):
        # A single full op
        def forward(self):
            return torch.full((3, 3), 4.5)

    class AddConstFull(torch.nn.Module):
        # Input + a full with constant value.
        def forward(self, x: torch.Tensor):
            return torch.full((2, 2, 3, 3), 4.5, dtype=torch.float32) + x

    class AddVariableFull(torch.nn.Module):
        # Shapes of the input tensors; every entry is a tuple, including the
        # rank-1 case, which needs a trailing comma to be a tuple.
        sizes = [
            (5,),
            (5, 5),
            (5, 5, 5),
            (1, 5, 5, 5),
        ]
        # One (tensor, fill_value) pair per shape, wrapped in a 1-tuple so
        # that parameterized.expand passes the pair as a single argument.
        test_parameters = [((torch.randn(n) * 10 - 5, 3.2),) for n in sizes]

        def forward(self, x: torch.Tensor, y):
            # Input + a full with the shape from the input and a given value 'y'.
            return x + torch.full(x.shape, y)

    def _test_full_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        example_data: Tuple,
        test_data: Tuple | None = None,
    ):
        """Run the non-quantized TOSA pipeline and compare outputs.

        Args:
            module: Module containing exactly one full op.
            example_data: Inputs used to export the module.
            test_data: Inputs used for the output comparison; defaults to
                ``example_data`` when omitted.
        """
        if test_data is None:
            test_data = example_data
        (
            ArmTester(
                module,
                example_inputs=example_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple,
        permute_memory_to_nhwc: bool,
    ):
        """Run the quantized TOSA pipeline and compare outputs.

        Args:
            module: Module containing exactly one full op.
            test_data: Inputs used both for export and for comparison.
            permute_memory_to_nhwc: Forwarded to the TOSA compile spec.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    permute_memory_to_nhwc=permute_memory_to_nhwc
                ),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_full_tosa_u55_pipeline(self, module: torch.nn.Module, test_data: Tuple):
        """Run the quantized U55 pipeline up to ``to_executorch``.

        Unlike the TOSA pipelines, no outputs are run or compared here.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.full.default": 1})
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_full_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def test_only_full_tosa_MI(self):
        self._test_full_tosa_MI_pipeline(self.Full(), ())

    def test_const_full_tosa_MI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_MI_pipeline(self.AddConstFull(), (_input,))

    def test_const_full_nhwc_tosa_BI(self):
        _input = torch.rand((2, 2, 3, 3)) * 10
        self._test_full_tosa_BI_pipeline(self.AddConstFull(), (_input,), True)

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_MI(self, test_tensor: Tuple):
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=test_tensor
        )

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_tosa_BI(self, test_tensor: Tuple):
        self._test_full_tosa_BI_pipeline(self.AddVariableFull(), test_tensor, False)

    @parameterized.expand(AddVariableFull.test_parameters)
    def test_full_u55_BI(self, test_tensor: Tuple):
        self._test_full_tosa_u55_pipeline(
            self.AddVariableFull(),
            test_tensor,
        )

    # This fails since full outputs int64 by default if 'fill_value' is an
    # integer, which our backend doesn't support.
    @unittest.expectedFailure
    def test_integer_value(self):
        _input = torch.ones((2, 2))
        integer_fill_value = 1
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(), example_data=(_input, integer_fill_value)
        )

    # This fails since the fill value in the full tensor is set at compile
    # time by the example data (1.). The test data tries to set it again at
    # runtime (to 2.), but that has no effect. In eager mode the fill value
    # can be set at runtime, so the outputs do not match.
    @unittest.expectedFailure
    def test_set_value_at_runtime(self):
        _input = torch.ones((2, 2))
        example_fill_value = 1.0
        test_fill_value = 2.0
        self._test_full_tosa_MI_pipeline(
            self.AddVariableFull(),
            example_data=(_input, example_fill_value),
            test_data=(_input, test_fill_value),
        )

backends/arm/test/tester/arm_tester.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,15 @@ def run_method_and_compare_outputs(
249249
else:
250250
test_input = reference_input
251251

252+
# Test parameters can include constants that are used in eager mode but are already set as attributes
253+
# in TOSA. Therefore, only accept torch.Tensor inputs.
254+
test_input = [
255+
tensor for tensor in test_input if isinstance(tensor, torch.Tensor)
256+
]
257+
252258
input_shapes = [
253-
generated_input.shape for generated_input in reference_input
259+
generated_input.shape if hasattr(generated_input, "shape") else (1,)
260+
for generated_input in reference_input
254261
]
255262
print(f"Run {run_iteration} with input shapes: {input_shapes}")
256263

@@ -274,7 +281,7 @@ def transpose_data_format(
274281
dim_order = (0, 2, 3, 1)
275282
inputs_transposed = list(data)
276283
for i in range(len(data)):
277-
if len(data[i].shape) == 4:
284+
if hasattr(data[i], "shape") and len(data[i].shape) == 4:
278285
inputs_transposed[i] = np.transpose(data[i], dim_order)
279286
return tuple(inputs_transposed)
280287

@@ -298,7 +305,8 @@ def _compare_outputs(
298305
path_to_tosa_files = self.runner_util.intermediate_path
299306

300307
export_stage = self.stages.get(self.stage_name(tester.Export), None)
301-
if export_stage is not None:
308+
quantize_stage = self.stages.get(self.stage_name(tester.Quantize), None)
309+
if export_stage is not None and quantize_stage is not None:
302310
input_names = _get_input_names(export_stage.artifact)
303311
output_node = _get_output_node(export_stage.artifact)
304312
qp_input = _get_input_quantization_params(

backends/arm/tosa_quant_utils.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,26 @@ class QuantArgs(NamedTuple):
2727

2828

2929
def is_quant_node(node: torch.fx.Node):
30-
consumer_node = list(node.users)[0]
31-
input = node.all_input_nodes[0]
32-
33-
# For Rank > 2 Linear layers, the quant node is after the view_copy
34-
if (
35-
node.target == exir_ops.edge.aten.addmm.default
36-
and consumer_node.target == exir_ops.edge.aten.view_copy.default
37-
):
38-
consumer_consumer_node = list(consumer_node.users)[0]
39-
return True if consumer_consumer_node.target == q_op else False
40-
41-
return (
42-
consumer_node.target == q_op
43-
or node.target in dq_q_ops
44-
or input.target in dq_q_ops
45-
)
30+
31+
consumer_node_condition = False
32+
if len(list(node.users)) > 0:
33+
consumer_node = list(node.users)[0]
34+
35+
# For Rank > 2 Linear layers, the quant node is after the view_copy
36+
if (
37+
node.target == exir_ops.edge.aten.addmm.default
38+
and consumer_node.target == exir_ops.edge.aten.view_copy.default
39+
):
40+
consumer_consumer_node = list(consumer_node.users)[0]
41+
return True if consumer_consumer_node.target == q_op else False
42+
consumer_node_condition = consumer_node.target == q_op
43+
44+
input_node_condition = False
45+
if len(node.all_input_nodes) > 0:
46+
input = node.all_input_nodes[0]
47+
input_node_condition = input.target in dq_q_ops
48+
49+
return node.target in dq_q_ops or consumer_node_condition or input_node_condition
4650

4751

4852
def get_quant_node_dtype(node: torch.fx.Node):

0 commit comments

Comments
 (0)