Skip to content

Commit ca25c7f

Browse files
AdrianLundellZonglin Peng
authored and
Zonglin Peng
committed
Arm backend: Add FuseViewCopyTransform and FuseConstantsPass in arm_pass_manager (#8997)
Add FuseViewCopyTransform and FuseConstantsPass in arm_pass_manager These passes both removes redundant ops from the graph: - FuseViewCopyTransform pass is added from backends/transforms to merge sequential view ops. - FuseConstantOpsPass is created to compute ops with constant inputs AOT - This is not done in cases where the result is a larger tensor, to avoid increasing the constant memory size. - For BI, ops are quantized with the q/dq-ops as to not change the behaviour of the graph. - Pass order is important: the pass must be placed after all passes which may add constant ops, but before the InsertTableOpsPass, since it doesn't handle TOSA _table-ops. Signed-off-by: Adrian Lundell <[email protected]>
1 parent 5193c08 commit ca25c7f

File tree

4 files changed

+319
-2
lines changed

4 files changed

+319
-2
lines changed

backends/arm/_passes/arm_pass_manager.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
RetraceFoldedDtypesPass,
5252
)
5353
from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
54+
from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
5455
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
5556
FuseQuantizedActivationPass,
5657
)
@@ -78,6 +79,7 @@
7879
UnsqueezeScalarPlaceholdersPass,
7980
)
8081
from executorch.backends.arm.tosa_specification import TosaSpecification
82+
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
8183

8284
from executorch.backends.transforms.replace_scalar_with_tensor import (
8385
ReplaceScalarWithTensorArgPass,
@@ -114,7 +116,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
114116
self.add_pass(QuantizeOperatorArguments())
115117
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
116118
self.add_pass(RetraceFoldedDtypesPass())
117-
self.add_pass(InsertTableOpsPass(exported_program))
118119

119120
self.add_pass(RemoveClonePass())
120121
self.add_pass(SizeAdjustConv2DPass())
@@ -128,8 +129,12 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
128129
self.add_pass(DecomposeSelectPass())
129130
self.add_pass(ConvertSqueezesToViewPass())
130131

132+
self.add_pass(FuseViewCopyTransform())
133+
self.add_pass(FuseConstantOpsPass(exported_program))
134+
self.add_pass(InsertTableOpsPass(exported_program))
131135
self.add_pass(AnnotateChannelsLastDimOrder())
132136
self.add_pass(InsertRescalePass())
137+
133138
return self._transform(exported_program.graph_module)
134139

135140
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
@@ -155,7 +160,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
155160
self.add_pass(QuantizeOperatorArguments())
156161
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
157162
self.add_pass(RetraceFoldedDtypesPass())
158-
self.add_pass(InsertTableOpsPass(exported_program))
159163

160164
self.add_pass(RemoveClonePass())
161165
self.add_pass(SizeAdjustConv2DPass())
@@ -169,6 +173,9 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
169173
self.add_pass(DecomposeSelectPass())
170174
self.add_pass(ConvertSqueezesToViewPass())
171175

176+
self.add_pass(FuseViewCopyTransform())
177+
self.add_pass(FuseConstantOpsPass(exported_program))
178+
self.add_pass(InsertTableOpsPass(exported_program))
172179
self.add_pass(AnnotateChannelsLastDimOrder())
173180
self.add_pass(InsertRescalePass())
174181

backends/arm/_passes/arm_pass_utils.py

+25
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from torch._ops import OpOverload
2828
from torch._subclasses.fake_tensor import FakeTensor
29+
from torch.export.graph_signature import InputKind
2930

3031

3132
def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -44,6 +45,30 @@ def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
4445
)
4546

4647

48+
def get_constant_placeholder_kind(
49+
exp_prog: ExportedProgram, node: torch.fx.Node
50+
) -> InputKind:
51+
if is_param(exp_prog, node):
52+
return InputKind.PARAMETER
53+
if is_buffer(exp_prog, node):
54+
return InputKind.BUFFER
55+
if is_lifted_tensor_constant(exp_prog, node):
56+
return InputKind.CONSTANT_TENSOR
57+
58+
raise RuntimeError("Node is neither PARAMETER, BUFFER nor CONSTANT_TENSOR")
59+
60+
61+
def is_persistent_buffer(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool | None:
62+
if is_buffer(exp_prog, node):
63+
buffer_name = exp_prog.graph_signature.inputs_to_buffers[node.name]
64+
if buffer_name in exp_prog.graph_signature.non_persistent_buffers:
65+
return False
66+
else:
67+
return True
68+
69+
return None
70+
71+
4772
def get_param_tensor(
4873
exp_prog: ExportedProgram, node: torch.fx.Node
4974
) -> Optional[torch.Tensor]:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
8+
import torch._export.utils
9+
from executorch.backends.arm._passes.arm_pass_utils import (
10+
get_constant_placeholder_kind,
11+
get_param_tensor,
12+
is_persistent_buffer,
13+
)
14+
from executorch.backends.transforms.utils import (
15+
create_constant_placeholder,
16+
delete_constant_placeholder,
17+
)
18+
from executorch.exir import ExportedProgram
19+
from executorch.exir.dialects._ops import ops as exir_ops
20+
from executorch.exir.pass_base import ExportPass, PassResult
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
class FuseConstantOpsPass(ExportPass):
26+
"""
27+
Fuses ops with only placeholder parameters into one placeholder parameter node with the op
28+
pre-calulcated on its data.
29+
30+
Original:
31+
state_dict = {x_tensor_name : data}
32+
def f():
33+
return x.view(...)
34+
35+
After pass:
36+
state_dict = {x_tensor_name_fused_const : data.view(...)}
37+
def f():
38+
return x
39+
"""
40+
41+
def __init__(self, exported_program: ExportedProgram) -> None:
42+
super().__init__()
43+
self.exported_program = exported_program
44+
45+
def fuse_nodes(self, node) -> bool:
46+
"""
47+
Takes a node with only parameter inputs and replaces it with one constant tensor node with
48+
the operations already carried out on the data.
49+
"""
50+
51+
if node.target == exir_ops.edge.aten.full.default:
52+
# Create data from args
53+
size, fill_value = node.args
54+
dtype = node.kwargs["dtype"]
55+
data = torch.full(size, float(fill_value), dtype=dtype)
56+
57+
insert_pos = list(node.graph.nodes)[0]
58+
else:
59+
# Extract tensors and args from the node
60+
61+
if len(node.all_input_nodes) == 0:
62+
raise RuntimeError("No inputs found")
63+
64+
data_list = [
65+
get_param_tensor(self.exported_program, input_node)
66+
for input_node in node.all_input_nodes
67+
]
68+
69+
args = node.args[len(node.all_input_nodes) :]
70+
kwargs = node.kwargs
71+
72+
if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0:
73+
dequantize_op = (
74+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
75+
)
76+
77+
for i in range(len(node.all_input_nodes)):
78+
q_params = node.meta["input_qparams"][i]
79+
data_list[i] = dequantize_op(
80+
data_list[i],
81+
q_params.scale,
82+
q_params.zp,
83+
q_params.qmin,
84+
q_params.qmax,
85+
q_params.dtype,
86+
)
87+
88+
# Run the op on the extracted tensor
89+
data = node.target(*data_list, *args, **kwargs)
90+
91+
if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
92+
quantize_op = (
93+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
94+
)
95+
q_params = node.meta["output_qparams"][0]
96+
data = quantize_op(
97+
data,
98+
q_params.scale,
99+
q_params.zp,
100+
q_params.qmin,
101+
q_params.qmax,
102+
q_params.dtype,
103+
)
104+
105+
insert_pos = list(node.all_input_nodes)[0]
106+
107+
# Make new node the same kind as the first constant input
108+
input_kind = get_constant_placeholder_kind(self.exported_program, insert_pos)
109+
persistent_buffer = is_persistent_buffer(self.exported_program, insert_pos)
110+
111+
# Create new node
112+
with node.graph.inserting_before(insert_pos):
113+
const_node = create_constant_placeholder(
114+
exp_program=self.exported_program,
115+
graph=node.graph,
116+
kind=input_kind,
117+
name=node.name + "_fused_const",
118+
data=data,
119+
persistent_buffer=persistent_buffer,
120+
)
121+
122+
node.replace_all_uses_with(const_node)
123+
124+
return True
125+
126+
def call(self, graph_module):
127+
modified = True
128+
input_nodes_to_delete = []
129+
for node in graph_module.graph.nodes:
130+
if node.op != "call_function":
131+
continue
132+
if node.target == torch.ops.tosa._table.default:
133+
continue
134+
if node.target == exir_ops.edge.aten.repeat.default:
135+
_, multiples = node.args
136+
# Do not fuse if the repeat creates a larger output, i.e. any multiple > 1
137+
if any((multiple > 1 for multiple in multiples)):
138+
continue
139+
140+
input_nodes = node.all_input_nodes
141+
input_nodes_constant = (
142+
torch._export.utils.is_param(self.exported_program, input_node)
143+
or torch._export.utils.is_lifted_tensor_constant(
144+
self.exported_program, input_node
145+
)
146+
or torch._export.utils.is_buffer(self.exported_program, input_node)
147+
for input_node in input_nodes
148+
)
149+
input_nodes_single_users = (
150+
len(input_node.users) == 1 for input_node in input_nodes
151+
)
152+
153+
if all(input_nodes_constant) and all(input_nodes_single_users):
154+
try:
155+
self.fuse_nodes(node)
156+
graph_module.recompile() # Recompile needed to catch chains of constant ops
157+
input_nodes_to_delete.extend(input_nodes)
158+
except Exception as e:
159+
logger.warning(
160+
f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
161+
)
162+
163+
if modified:
164+
graph_module.graph.eliminate_dead_code()
165+
for input_node in input_nodes_to_delete:
166+
delete_constant_placeholder(self.exported_program, input_node)
167+
168+
graph_module = super().call(graph_module).graph_module
169+
170+
return PassResult(graph_module, True)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import operator
7+
from typing import Tuple
8+
9+
import torch
10+
from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
PassPipeline,
14+
TosaPipelineBI,
15+
)
16+
17+
input_t = Tuple[torch.Tensor] # Input x
18+
19+
20+
class FuseParameter(torch.nn.Module):
21+
ops_before_pass = {
22+
"executorch_exir_dialects_edge__ops_aten_full_default": 1,
23+
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 2,
24+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1,
25+
"executorch_exir_dialects_edge__ops_aten_addmm_default": 1,
26+
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
27+
}
28+
ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}
29+
ops_not_after_pass = [
30+
"executorch_exir_dialects_edge__ops_aten_full_default",
31+
"executorch_exir_dialects_edge__ops_aten_view_copy_default",
32+
"executorch_exir_dialects_edge__ops_aten_permute_copy_default",
33+
"executorch_exir_dialects_edge__ops_aten_addmm_default",
34+
]
35+
36+
def __init__(
37+
self,
38+
in_features: int = 1,
39+
out_features: int = 1,
40+
bias: bool = True,
41+
):
42+
super().__init__()
43+
self.fc = torch.nn.Linear(
44+
in_features=in_features,
45+
out_features=out_features,
46+
bias=bias,
47+
)
48+
49+
def forward(self, x):
50+
return self.fc(torch.ones(1)) + x
51+
52+
53+
class FuseBuffer(torch.nn.Module):
54+
ops_before_pass = {
55+
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
56+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
57+
}
58+
ops_after_pass = {
59+
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
60+
"executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1,
61+
}
62+
ops_not_after_pass = [
63+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
64+
]
65+
66+
def forward(self, x: torch.Tensor):
67+
return (x + 1) * 2
68+
69+
70+
class FuseLiftedTensor(torch.nn.Module):
71+
ops_before_pass = {
72+
"executorch_exir_dialects_edge__ops_aten_select_copy_int": 1,
73+
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
74+
}
75+
ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1}
76+
ops_not_after_pass = ["executorch_exir_dialects_edge__ops_aten_select_copy_int"]
77+
78+
def __init__(
79+
self,
80+
):
81+
super().__init__()
82+
self.lifted_tensor = torch.rand(2)
83+
84+
def forward(self, x: torch.Tensor) -> torch.Tensor:
85+
sliced = self.lifted_tensor[0]
86+
return operator.add(sliced, x)
87+
88+
89+
modules = {
90+
"fuse_parameter": FuseParameter(),
91+
"fuse_buffer": FuseBuffer(),
92+
"fuse_const_tensor": FuseLiftedTensor(),
93+
}
94+
95+
96+
@common.parametrize("module", modules)
97+
def test_fuse_batchnorm_tosa_MI(module):
98+
pipeline = PassPipeline[input_t](
99+
module=module,
100+
test_data=(torch.rand(1),),
101+
tosa_version="TOSA-0.80+MI",
102+
ops_before_pass=module.ops_before_pass,
103+
ops_after_pass=module.ops_after_pass,
104+
ops_not_after_pass=module.ops_not_after_pass,
105+
passes_with_exported_program=[FuseConstantOpsPass],
106+
)
107+
pipeline.run()
108+
109+
110+
@common.parametrize("module", modules)
111+
def test_fuse_batchnorm_tosa_BI(module):
112+
pipeline = TosaPipelineBI[input_t](
113+
module, (torch.rand(10, 10),), [], [], use_to_edge_transform_and_lower=True
114+
)
115+
pipeline.run()

0 commit comments

Comments
 (0)