diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 3d73e7f8c1e..9f556135dfb 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -104,6 +104,16 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) return total +def op_counts_match( + graph_module: torch.fx.GraphModule, + expected_op_counts: dict[EdgeOpOverload, int], +) -> bool: + for op, count in expected_op_counts.items(): + if count_node(graph_module, op) != count: + return False + return True + + # Testing utils # Return the compute/function nodes in the graph def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]: diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 942f6d55533..3cac7514fff 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -33,7 +33,7 @@ from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue from executorch.exir.pass_manager import PassManager, PassType from executorch.exir.passes import dead_code_elimination_pass @@ -745,6 +745,68 @@ def permute_shape( return [shape[p] for p in permute_dims] +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class RemoveBranchedQuantDequant(ExportPass): + """ + This pass looks for adjacent quant and dequant nodes with identical + parameters, where the quant node has other users in addition to the + dequant. The quant and dequant pair would be removed by the + FuseQuantDequantToRequantizePass if not for the multiple users. This pass + removes just the dequant node by connecting it to the quant's parent node + """ + + quantize_op_packets: set[EdgeOpOverloadPacket] = { + exir_ops.edge.cadence.quantize_per_tensor, + exir_ops.edge.quantized_decomposed.quantize_per_tensor, + } + dequantize_op_packets: set[EdgeOpOverloadPacket] = { + exir_ops.edge.cadence.dequantize_per_tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor, + } + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self.remove_branched( + graph_module, self.quantize_op_packets, self.dequantize_op_packets + ) + self.remove_branched( + graph_module, self.dequantize_op_packets, self.quantize_op_packets + ) + + graph_module.graph.eliminate_dead_code() + result = super().call(graph_module) + return result + + def remove_branched( + self, + graph_module: torch.fx.GraphModule, + producer_pkts: set[EdgeOpOverloadPacket], + consumer_pkts: set[EdgeOpOverloadPacket], + ) -> None: + for node in graph_module.graph.nodes: + if ( + node.op != "call_function" + or not isinstance(node.target, EdgeOpOverload) + or get_edge_overload_packet(node.target) not in producer_pkts + ): + continue + + if len(node.users) < 2: + continue + + for user in node.users: + if ( + not isinstance(user.target, EdgeOpOverload) + or get_edge_overload_packet(user.target) not in consumer_pkts + ): + continue + + # check qparams match + if node.args[1:] != user.args[1:]: + continue + + user.replace_all_uses_with(node.args[0]) + + # The following class consolidates functions to remove ops that are redundant # in Jarvis. Currently, each function in this class iterates over each node of # the graph module once. In future, we could consolidate them into a monolithic @@ -765,4 +827,5 @@ class CadenceRemoveNops: RemoveNopMulOpPass, RemoveNopAddOpPass, RemoveNopLinalgVectorNormOpPass, + RemoveBranchedQuantDequant, ] diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index 792a6ee4166..4af3eafb72a 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -20,7 +20,7 @@ FuseTransposeOpPairsPass, ) from executorch.backends.cadence.aot.graph_builder import GraphBuilder -from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch import nn @@ -32,8 +32,7 @@ def check_op_counts( graph_module: torch.fx.GraphModule, expected_op_counts: dict[EdgeOpOverload, int], ) -> None: - for op, count in expected_op_counts.items(): - self.assertEqual(count_node(graph_module, op), count) + self.assertTrue(op_counts_match(graph_module, expected_op_counts)) class TestFusionPasses(TestFusionPassesBase): diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index 348e0b5de83..0c802f9cbf5 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -17,10 +17,11 @@ from executorch.backends.cadence.aot import compiler from executorch.backends.cadence.aot.compiler import export_to_edge -from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer from executorch.backends.cadence.aot.remove_ops import ( RemoveAliasCopyOpPass, + RemoveBranchedQuantDequant, RemoveCloneOpPass, RemoveContiguousOpPass, RemoveDetachCopyPass, @@ -709,3 +710,34 @@ def forward(self, x): self.assertEqual( count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2 ) + + def test_remove_dequant_on_branch(self): + class M(torch.nn.Module): + def forward(self, x): + x = torch.abs(x) + x0 = torch.ops.quantized_decomposed.quantize_per_tensor( + x, 1.2, 3, 0, 127, torch.int8 + ) + x1 = torch.abs(x0) + y0 = torch.ops.quantized_decomposed.dequantize_per_tensor( + x0, 1.2, 3, 0, 127, torch.int8 + ) + y1 = y0.view(-1) + return x1, y1 + + inputs = torch.rand(1, 8, 4, 6) + model = M() + graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module + + graph_module = RemoveBranchedQuantDequant()(graph_module).graph_module + self.assertTrue( + op_counts_match( + graph_module, + expected_op_counts={ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1, + # we expect the pass to remove the dequantize node + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0, + exir_ops.edge.aten.abs.default: 2, + }, + ) + )