Skip to content

Add a pass to remove certain redundant branched quant/dequant nodes #8896

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions backends/cadence/aot/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target)
return total


def op_counts_match(
    graph_module: torch.fx.GraphModule,
    expected_op_counts: dict[EdgeOpOverload, int],
) -> bool:
    """
    Return True iff every op listed in `expected_op_counts` occurs in
    `graph_module` exactly the given number of times, as reported by
    `count_node`. Ops not listed are not checked.
    """
    # all() short-circuits on the first mismatch, same as the early return
    # in a manual loop, but reads as a single predicate.
    return all(
        count_node(graph_module, op) == count
        for op, count in expected_op_counts.items()
    )


# Testing utils
# Return the compute/function nodes in the graph
def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
Expand Down
65 changes: 64 additions & 1 deletion backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
Expand Down Expand Up @@ -745,6 +745,68 @@ def permute_shape(
return [shape[p] for p in permute_dims]


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveBranchedQuantDequant(ExportPass):
    """
    Remove a dequant node that directly follows a quant node with identical
    parameters when that quant node also feeds other consumers. Such a
    quant/dequant round trip is a no-op, but FuseQuantDequantToRequantizePass
    skips the pair because of the extra users; this pass instead reroutes the
    dequant's consumers to the quant node's own input, leaving the quant in
    place for its remaining users.
    """

    quantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.quantize_per_tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
    }
    dequantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.dequantize_per_tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Handle both orientations: quant followed by dequant, and
        # dequant followed by quant.
        self.remove_branched(
            graph_module, self.quantize_op_packets, self.dequantize_op_packets
        )
        self.remove_branched(
            graph_module, self.dequantize_op_packets, self.quantize_op_packets
        )

        # The bypassed consumer nodes are now dead; drop them before the
        # base-class pass runs.
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)

    def remove_branched(
        self,
        graph_module: torch.fx.GraphModule,
        producer_pkts: set[EdgeOpOverloadPacket],
        consumer_pkts: set[EdgeOpOverloadPacket],
    ) -> None:
        for producer in graph_module.graph.nodes:
            # Guard clauses: only call_function nodes whose overload packet
            # is one of the producer ops are of interest.
            if producer.op != "call_function":
                continue
            if not isinstance(producer.target, EdgeOpOverload):
                continue
            if get_edge_overload_packet(producer.target) not in producer_pkts:
                continue
            # Only the branched (multi-user) case is handled here; the
            # single-user pair is the fusion pass's job.
            if len(producer.users) < 2:
                continue

            for consumer in producer.users:
                matches = (
                    isinstance(consumer.target, EdgeOpOverload)
                    and get_edge_overload_packet(consumer.target) in consumer_pkts
                    # The round trip is only a no-op when the quantization
                    # parameters (everything past the tensor arg) agree.
                    and consumer.args[1:] == producer.args[1:]
                )
                if matches:
                    # Reroute the consumer's users straight to the
                    # producer's input tensor, bypassing the pair.
                    consumer.replace_all_uses_with(producer.args[0])


# The following class consolidates functions to remove ops that are redundant
# in Jarvis. Currently, each function in this class iterates over each node of
# the graph module once. In future, we could consolidate them into a monolithic
Expand All @@ -765,4 +827,5 @@ class CadenceRemoveNops:
RemoveNopMulOpPass,
RemoveNopAddOpPass,
RemoveNopLinalgVectorNormOpPass,
RemoveBranchedQuantDequant,
]
5 changes: 2 additions & 3 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
FuseTransposeOpPairsPass,
)
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch import nn
Expand All @@ -32,8 +32,7 @@ def check_op_counts(
graph_module: torch.fx.GraphModule,
expected_op_counts: dict[EdgeOpOverload, int],
) -> None:
for op, count in expected_op_counts.items():
self.assertEqual(count_node(graph_module, op), count)
self.assertTrue(op_counts_match(graph_module, expected_op_counts))


class TestFusionPasses(TestFusionPassesBase):
Expand Down
34 changes: 33 additions & 1 deletion backends/cadence/aot/tests/test_remove_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge

from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
from executorch.backends.cadence.aot.remove_ops import (
RemoveAliasCopyOpPass,
RemoveBranchedQuantDequant,
RemoveCloneOpPass,
RemoveContiguousOpPass,
RemoveDetachCopyPass,
Expand Down Expand Up @@ -709,3 +710,34 @@ def forward(self, x):
self.assertEqual(
count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
)

def test_remove_dequant_on_branch(self):
    # Build a graph where the quantize node has two consumers: an abs that
    # keeps the quantized tensor alive, and a dequantize with identical
    # qparams that should be removed by the pass.
    class Model(torch.nn.Module):
        def forward(self, x):
            x = torch.abs(x)
            quantized = torch.ops.quantized_decomposed.quantize_per_tensor(
                x, 1.2, 3, 0, 127, torch.int8
            )
            branch = torch.abs(quantized)
            dequantized = torch.ops.quantized_decomposed.dequantize_per_tensor(
                quantized, 1.2, 3, 0, 127, torch.int8
            )
            flattened = dequantized.view(-1)
            return branch, flattened

    inputs = torch.rand(1, 8, 4, 6)
    graph_module = (
        export_to_edge(Model(), (inputs,)).exported_program().graph_module
    )

    graph_module = RemoveBranchedQuantDequant()(graph_module).graph_module
    self.assertTrue(
        op_counts_match(
            graph_module,
            expected_op_counts={
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
                # the dequantize node on the branch must be gone
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.aten.abs.default: 2,
            },
        )
    )
Loading