|
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
| 5 | + |
import logging

import torch._export.utils
from executorch.backends.arm._passes.arm_pass_utils import (
    get_constant_placeholder_kind,
    get_param_tensor,
    is_persistent_buffer,
)
from executorch.backends.transforms.utils import (
    create_constant_placeholder,
    delete_constant_placeholder,
)
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

logger = logging.getLogger(__name__)
| 23 | + |
| 24 | + |
class FuseConstantOpsPass(ExportPass):
    """
    Fuses ops with only placeholder parameters into one placeholder parameter node with the op
    pre-calculated on its data.

    Original:
        state_dict = {x_tensor_name : data}
        def f():
            return x.view(...)

    After pass:
        state_dict = {x_tensor_name_fused_const : data.view(...)}
        def f():
            return x
    """

    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
        # Needed to look up parameter/buffer data and to create/delete
        # constant placeholders in the program's state dict.
        self.exported_program = exported_program

    def fuse_nodes(self, node) -> bool:
        """
        Takes a node with only parameter inputs and replaces it with one constant tensor node with
        the operations already carried out on the data.

        Raises:
            RuntimeError: if the node is not aten.full and has no input nodes.
        """

        if node.target == exir_ops.edge.aten.full.default:
            # aten.full has no tensor inputs; materialize its data directly
            # from the (size, fill_value) args and the dtype kwarg.
            size, fill_value = node.args
            dtype = node.kwargs["dtype"]
            data = torch.full(size, float(fill_value), dtype=dtype)

            # No constant input exists to anchor on, so insert the new
            # placeholder at the very beginning of the graph.
            insert_pos = list(node.graph.nodes)[0]
        else:
            # Extract tensors and args from the node

            if len(node.all_input_nodes) == 0:
                raise RuntimeError("No inputs found")

            data_list = [
                get_param_tensor(self.exported_program, input_node)
                for input_node in node.all_input_nodes
            ]

            # Non-tensor args trail the tensor inputs in node.args.
            args = node.args[len(node.all_input_nodes) :]
            kwargs = node.kwargs

            if "input_qparams" in node.meta and len(node.meta["input_qparams"]) > 0:
                # Inputs are quantized: dequantize the stored data so the op
                # is evaluated on dequantized values.
                dequantize_op = (
                    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
                )

                for i in range(len(node.all_input_nodes)):
                    q_params = node.meta["input_qparams"][i]
                    data_list[i] = dequantize_op(
                        data_list[i],
                        q_params.scale,
                        q_params.zp,
                        q_params.qmin,
                        q_params.qmax,
                        q_params.dtype,
                    )

            # Run the op on the extracted tensor
            data = node.target(*data_list, *args, **kwargs)

            if "output_qparams" in node.meta and len(node.meta["output_qparams"]) > 0:
                # Re-quantize the result so the fused constant matches the
                # node's original output encoding. Only the first set of
                # output qparams is used (single-output op assumed).
                quantize_op = (
                    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
                )
                q_params = node.meta["output_qparams"][0]
                data = quantize_op(
                    data,
                    q_params.scale,
                    q_params.zp,
                    q_params.qmin,
                    q_params.qmax,
                    q_params.dtype,
                )

            insert_pos = list(node.all_input_nodes)[0]

        # Make new node the same kind as the first constant input
        input_kind = get_constant_placeholder_kind(self.exported_program, insert_pos)
        persistent_buffer = is_persistent_buffer(self.exported_program, insert_pos)

        # Create new node
        with node.graph.inserting_before(insert_pos):
            const_node = create_constant_placeholder(
                exp_program=self.exported_program,
                graph=node.graph,
                kind=input_kind,
                name=node.name + "_fused_const",
                data=data,
                persistent_buffer=persistent_buffer,
            )

        node.replace_all_uses_with(const_node)

        return True

    def call(self, graph_module):
        """Fuse every fuseable constant op in ``graph_module``.

        A node is fuseable when all of its inputs are constant placeholders
        (params, buffers or lifted tensor constants) with no other users.
        Fusion failures are logged and skipped, not raised.
        """
        # FIX: `modified` was initialized to True and never updated, so the
        # cleanup below ran unconditionally and the pass always reported that
        # it changed the graph. Track actual fusions instead.
        modified = False
        input_nodes_to_delete = []
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue
            if node.target == torch.ops.tosa._table.default:
                continue
            if node.target == exir_ops.edge.aten.repeat.default:
                _, multiples = node.args
                # Do not fuse if the repeat creates a larger output, i.e. any multiple > 1
                if any(multiple > 1 for multiple in multiples):
                    continue

            input_nodes = node.all_input_nodes
            input_nodes_constant = (
                torch._export.utils.is_param(self.exported_program, input_node)
                or torch._export.utils.is_lifted_tensor_constant(
                    self.exported_program, input_node
                )
                or torch._export.utils.is_buffer(self.exported_program, input_node)
                for input_node in input_nodes
            )
            input_nodes_single_users = (
                len(input_node.users) == 1 for input_node in input_nodes
            )

            if all(input_nodes_constant) and all(input_nodes_single_users):
                try:
                    self.fuse_nodes(node)
                    modified = True
                    graph_module.recompile()  # Recompile needed to catch chains of constant ops
                    input_nodes_to_delete.extend(input_nodes)
                except Exception as e:
                    logger.warning(
                        f"\nFailed to fuse constant op {node.name} due to exception:\n{str(e)}"
                    )

        if modified:
            graph_module.graph.eliminate_dead_code()
            # The fused nodes' original constant inputs are now unused;
            # remove them from both the graph and the state dict.
            for input_node in input_nodes_to_delete:
                delete_constant_placeholder(self.exported_program, input_node)

            # Re-trace only when something actually changed.
            graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, modified)