From 9d82c074eea85dafccb7ef1d8075d2fae81eed87 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Mon, 3 Feb 2025 13:03:33 +0100 Subject: [PATCH] Add FuseViewCopyTransform and FuseConstantOpsPass in arm_pass_manager Both passes remove redundant ops from the graph: - FuseViewCopyTransform pass is added from backends/transforms to merge sequential view ops. - FuseConstantOpsPass is created to compute ops with constant inputs ahead of time (AOT). - This is not done in cases where the result is a larger tensor, to avoid increasing the constant memory size. - For BI, ops are quantized with the q/dq-ops so as not to change the behaviour of the graph. - Pass order is important: the pass must be placed after all passes which may add constant ops, but before the InsertTableOpsPass, since it doesn't handle TOSA _table-ops. Change-Id: I855b2cd969dce24ad6d3c21d9a3f5473ddc984b8 Signed-off-by: Adrian Lundell --- backends/arm/_passes/arm_pass_manager.py | 11 +- backends/arm/_passes/arm_pass_utils.py | 25 +++ .../arm/_passes/fuse_constant_ops_pass.py | 170 ++++++++++++++++++ .../passes/test_fuse_constant_ops_pass.py | 115 ++++++++++++ 4 files changed, 319 insertions(+), 2 deletions(-) create mode 100644 backends/arm/_passes/fuse_constant_ops_pass.py create mode 100644 backends/arm/test/passes/test_fuse_constant_ops_pass.py diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f8a4a40648f..26ff15db396 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -51,6 +51,7 @@ RetraceFoldedDtypesPass, ) from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass +from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found] FuseQuantizedActivationPass, ) @@ -78,6 +79,7 @@ UnsqueezeScalarPlaceholdersPass, ) from 
executorch.backends.arm.tosa_specification import TosaSpecification +from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.replace_scalar_with_tensor import ( ReplaceScalarWithTensorArgPass, @@ -114,7 +116,6 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(QuantizeOperatorArguments()) self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] self.add_pass(RetraceFoldedDtypesPass()) - self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(RemoveClonePass()) self.add_pass(SizeAdjustConv2DPass()) @@ -128,8 +129,12 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) + self.add_pass(FuseViewCopyTransform()) + self.add_pass(FuseConstantOpsPass(exported_program)) + self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(AnnotateChannelsLastDimOrder()) self.add_pass(InsertRescalePass()) + return self._transform(exported_program.graph_module) def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: @@ -155,7 +160,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(QuantizeOperatorArguments()) self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg] self.add_pass(RetraceFoldedDtypesPass()) - self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(RemoveClonePass()) self.add_pass(SizeAdjustConv2DPass()) @@ -169,6 +173,9 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeSelectPass()) self.add_pass(ConvertSqueezesToViewPass()) + self.add_pass(FuseViewCopyTransform()) + self.add_pass(FuseConstantOpsPass(exported_program)) + self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(AnnotateChannelsLastDimOrder()) self.add_pass(InsertRescalePass()) diff --git 
a/backends/arm/_passes/arm_pass_utils.py b/backends/arm/_passes/arm_pass_utils.py index 3445886ffa7..a8d06713678 100644 --- a/backends/arm/_passes/arm_pass_utils.py +++ b/backends/arm/_passes/arm_pass_utils.py @@ -26,6 +26,7 @@ ) from torch._ops import OpOverload from torch._subclasses.fake_tensor import FakeTensor +from torch.export.graph_signature import InputKind def is_get_attr_node(node: torch.fx.Node) -> bool: @@ -44,6 +45,30 @@ def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool: ) +def get_constant_placeholder_kind( + exp_prog: ExportedProgram, node: torch.fx.Node +) -> InputKind: + if is_param(exp_prog, node): + return InputKind.PARAMETER + if is_buffer(exp_prog, node): + return InputKind.BUFFER + if is_lifted_tensor_constant(exp_prog, node): + return InputKind.CONSTANT_TENSOR + + raise RuntimeError("Node is neither PARAMETER, BUFFER nor CONSTANT_TENSOR") + + +def is_persistent_buffer(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool | None: + if is_buffer(exp_prog, node): + buffer_name = exp_prog.graph_signature.inputs_to_buffers[node.name] + if buffer_name in exp_prog.graph_signature.non_persistent_buffers: + return False + else: + return True + + return None + + def get_param_tensor( exp_prog: ExportedProgram, node: torch.fx.Node ) -> Optional[torch.Tensor]: diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py new file mode 100644 index 00000000000..1fff7d76dfc --- /dev/null +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -0,0 +1,170 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +import torch._export.utils +from executorch.backends.arm._passes.arm_pass_utils import ( + get_constant_placeholder_kind, + get_param_tensor, + is_persistent_buffer, +) +from executorch.backends.transforms.utils import ( + create_constant_placeholder, + delete_constant_placeholder, +) +from executorch.exir import ExportedProgram +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +logger = logging.getLogger(__name__) + + +class FuseConstantOpsPass(ExportPass): + """ + Fuses ops with only placeholder parameters into one placeholder parameter node with the op + pre-calulcated on its data. + + Original: + state_dict = {x_tensor_name : data} + def f(): + return x.view(...) + + After pass: + state_dict = {x_tensor_name_fused_const : data.view(...)} + def f(): + return x + """ + + def __init__(self, exported_program: ExportedProgram) -> None: + super().__init__() + self.exported_program = exported_program + + def fuse_nodes(self, node) -> bool: + """ + Takes a node with only parameter inputs and replaces it with one constant tensor node with + the operations already carried out on the data. 
+ """ + + if node.target == exir_ops.edge.aten.full.default: + # Create data from args + size, fill_value = node.args + dtype = node.kwargs["dtype"] + data = torch.full(size, float(fill_value), dtype=dtype) + + insert_pos = list(node.graph.nodes)[0] + else: + # Extract tensors and args from the node + + if len(node.all_input_nodes) == 0: + raise RuntimeError("No inputs found") + + data_list = [ + get_param_tensor(self.exported_program, input_node) + for input_node in node.all_input_nodes + ] + + args = node.args[len(node.all_input_nodes) :] + kwargs = node.kwargs + + if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0: + dequantize_op = ( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ) + + for i in range(len(node.all_input_nodes)): + q_params = node.meta["input_qparams"][i] + data_list[i] = dequantize_op( + data_list[i], + q_params.scale, + q_params.zp, + q_params.qmin, + q_params.qmax, + q_params.dtype, + ) + + # Run the op on the extracted tensor + data = node.target(*data_list, *args, **kwargs) + + if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0: + quantize_op = ( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ) + q_params = node.meta["output_qparams"][0] + data = quantize_op( + data, + q_params.scale, + q_params.zp, + q_params.qmin, + q_params.qmax, + q_params.dtype, + ) + + insert_pos = list(node.all_input_nodes)[0] + + # Make new node the same kind as the first constant input + input_kind = get_constant_placeholder_kind(self.exported_program, insert_pos) + persistent_buffer = is_persistent_buffer(self.exported_program, insert_pos) + + # Create new node + with node.graph.inserting_before(insert_pos): + const_node = create_constant_placeholder( + exp_program=self.exported_program, + graph=node.graph, + kind=input_kind, + name=node.name + "_fused_const", + data=data, + persistent_buffer=persistent_buffer, + ) + + node.replace_all_uses_with(const_node) + + return True + + def 
call(self, graph_module): + modified = True + input_nodes_to_delete = [] + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.tosa._table.default: + continue + if node.target == exir_ops.edge.aten.repeat.default: + _, multiples = node.args + # Do not fuse if the repeat creates a larger output, i.e. any multiple > 1 + if any((multiple > 1 for multiple in multiples)): + continue + + input_nodes = node.all_input_nodes + input_nodes_constant = ( + torch._export.utils.is_param(self.exported_program, input_node) + or torch._export.utils.is_lifted_tensor_constant( + self.exported_program, input_node + ) + or torch._export.utils.is_buffer(self.exported_program, input_node) + for input_node in input_nodes + ) + input_nodes_single_users = ( + len(input_node.users) == 1 for input_node in input_nodes + ) + + if all(input_nodes_constant) and all(input_nodes_single_users): + try: + self.fuse_nodes(node) + graph_module.recompile() # Recompile needed to catch chains of constant ops + input_nodes_to_delete.extend(input_nodes) + except Exception as e: + logger.warning( + f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}" + ) + + if modified: + graph_module.graph.eliminate_dead_code() + for input_node in input_nodes_to_delete: + delete_constant_placeholder(self.exported_program, input_node) + + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/arm/test/passes/test_fuse_constant_ops_pass.py b/backends/arm/test/passes/test_fuse_constant_ops_pass.py new file mode 100644 index 00000000000..80d7293607f --- /dev/null +++ b/backends/arm/test/passes/test_fuse_constant_ops_pass.py @@ -0,0 +1,115 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import operator +from typing import Tuple + +import torch +from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + PassPipeline, + TosaPipelineBI, +) + +input_t = Tuple[torch.Tensor] # Input x + + +class FuseParameter(torch.nn.Module): + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_full_default": 1, + "executorch_exir_dialects_edge__ops_aten_view_copy_default": 2, + "executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1, + "executorch_exir_dialects_edge__ops_aten_addmm_default": 1, + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + } + ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1} + ops_not_after_pass = [ + "executorch_exir_dialects_edge__ops_aten_full_default", + "executorch_exir_dialects_edge__ops_aten_view_copy_default", + "executorch_exir_dialects_edge__ops_aten_permute_copy_default", + "executorch_exir_dialects_edge__ops_aten_addmm_default", + ] + + def __init__( + self, + in_features: int = 1, + out_features: int = 1, + bias: bool = True, + ): + super().__init__() + self.fc = torch.nn.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + ) + + def forward(self, x): + return self.fc(torch.ones(1)) + x + + +class FuseBuffer(torch.nn.Module): + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + } + ops_after_pass = { + "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_mul_Tensor": 1, + } + ops_not_after_pass = [ + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default" + ] + + def forward(self, x: torch.Tensor): + return (x + 1) * 2 + + +class FuseLiftedTensor(torch.nn.Module): + ops_before_pass = { + "executorch_exir_dialects_edge__ops_aten_select_copy_int": 1, + 
"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1, + } + ops_after_pass = {"executorch_exir_dialects_edge__ops_aten_add_Tensor": 1} + ops_not_after_pass = ["executorch_exir_dialects_edge__ops_aten_select_copy_int"] + + def __init__( + self, + ): + super().__init__() + self.lifted_tensor = torch.rand(2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + sliced = self.lifted_tensor[0] + return operator.add(sliced, x) + + +modules = { + "fuse_parameter": FuseParameter(), + "fuse_buffer": FuseBuffer(), + "fuse_const_tensor": FuseLiftedTensor(), +} + + +@common.parametrize("module", modules) +def test_fuse_batchnorm_tosa_MI(module): + pipeline = PassPipeline[input_t]( + module=module, + test_data=(torch.rand(1),), + tosa_version="TOSA-0.80+MI", + ops_before_pass=module.ops_before_pass, + ops_after_pass=module.ops_after_pass, + ops_not_after_pass=module.ops_not_after_pass, + passes_with_exported_program=[FuseConstantOpsPass], + ) + pipeline.run() + + +@common.parametrize("module", modules) +def test_fuse_batchnorm_tosa_BI(module): + pipeline = TosaPipelineBI[input_t]( + module, (torch.rand(10, 10),), [], [], use_to_edge_transform_and_lower=True + ) + pipeline.run()