Skip to content

Commit 8cfa858

Browse files
authored
[cadence][aot] Implement mul.Tensor to quant fusion.
Differential Revision: D76302365 Pull Request resolved: #11580
1 parent 4c2267b commit 8cfa858

File tree

2 files changed

+113
-0
lines changed

2 files changed

+113
-0
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,76 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
863863
return result
864864

865865

866+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoQuantPass(ExportPass):
    """
    Looks for the pattern where aten.mul.Tensor (by a scalar produced by
    aten.full) is followed by a quant node. If found, folds the
    multiplication into the quant scale — quantize(x * c, s) is equivalent
    to quantize(x, s * c) — and removes the mul node (and the full node,
    once it is otherwise unused).
    """

    def attempt_fusion(
        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
    ) -> None:
        """Try to fuse `mul_node` into its single quantize consumer; no-op otherwise."""
        # Partition mul's args into the scalar-producing full node(s) and the
        # remaining (tensor) operand(s).
        full_nodes = []
        other_args = []
        for arg in mul_node.args:
            if (
                isinstance(arg, torch.fx.Node)
                and arg.target == exir_ops.edge.aten.full.default
            ):
                full_nodes.append(arg)
            else:
                other_args.append(arg)

        # Only fuse the unambiguous case: exactly one scalar operand, exactly
        # one tensor operand, and a single consumer of the mul result.
        if len(full_nodes) != 1 or len(other_args) != 1 or len(mul_node.users) != 1:
            return

        full_node = full_nodes[0]
        tensor_arg = other_args[0]
        if not isinstance(tensor_arg, torch.fx.Node):
            return

        mul_user = next(iter(mul_node.users))

        if mul_user.target not in {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
        }:
            return

        quant_node = mul_user

        # Calculate the new scale value.
        prev_scale = quant_node.args[1]
        assert isinstance(prev_scale, (int, float))
        mul_scalar = full_node.args[1]
        assert isinstance(mul_scalar, (int, float))
        new_scale = float(prev_scale) * float(mul_scalar)

        logging.debug(
            f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
        )

        # Rewire quant to consume the mul's *tensor* operand directly.
        # NOTE: previously this used mul_node.args[0], which is wrong when the
        # full node happens to be the first operand of the mul.
        quant_node.replace_input_with(mul_node, tensor_arg)

        # Now update the scale in the args.
        new_quant_args = list(quant_node.args)
        new_quant_args[1] = new_scale
        quant_node.args = tuple(new_quant_args)

        # Clean up the mul_node: dropping its args releases its input
        # references so erase_node does not complain about remaining uses.
        mul_node.args = ()
        mul_node.users = {}
        graph_module.graph.erase_node(mul_node)

        # The full node may feed other consumers; only erase it once dead
        # (eliminate_dead_code in call() is the backstop for anything missed).
        if not full_node.users:
            graph_module.graph.erase_node(full_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Apply the fusion to every aten.mul.Tensor node in the graph."""
        for node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.mul.Tensor
        ):
            self.attempt_fusion(graph_module, node)
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)
934+
935+
866936
@register_cadence_pass(CadencePassAttribute(opt_level=1))
867937
class FuseMulTensorIntoDequantPass(ExportPass):
868938
"""

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
FuseMMWithAdd,
2121
FuseMulScalarIntoDequantPass,
2222
FuseMulTensorIntoDequantPass,
23+
FuseMulTensorIntoQuantPass,
2324
FuseQuantDequantToRequantizePass,
2425
FuseTransposeOrPermuteOpPairsPass,
2526
)
@@ -587,6 +588,48 @@ def test_fuse_mul_scalar_into_dequant(self):
587588
deq_scale = node.args[1]
588589
self.assertEqual(deq_scale, dequant_scale * mul_value)
589590

591+
def test_fuse_mul_into_quant(self):
592+
quant_scale = 1.5
593+
mul_value = 10
594+
595+
builder = GraphBuilder()
596+
x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
597+
full = builder.call_operator(
598+
op=exir_ops.edge.aten.full.default,
599+
args=([1], mul_value),
600+
)
601+
mul = builder.call_operator(
602+
op=exir_ops.edge.aten.mul.Tensor,
603+
args=(x, full),
604+
)
605+
quant = builder.call_operator(
606+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
607+
args=(mul, quant_scale, 0, 0, 255, torch.uint8),
608+
)
609+
builder.output(quant)
610+
graph_module = FuseMulTensorIntoQuantPass()(
611+
builder.get_graph_module()
612+
).graph_module
613+
614+
# verify that the mul and full ops were removed
615+
self.check_op_counts(
616+
graph_module,
617+
expected_op_counts={
618+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
619+
exir_ops.edge.aten.full.default: 0,
620+
exir_ops.edge.aten.mul.Tensor: 0,
621+
},
622+
)
623+
624+
# verify that the quant scale value was updated correctly
625+
for node in graph_module.graph.nodes:
626+
if (
627+
node.target
628+
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
629+
):
630+
deq_scale = node.args[1]
631+
self.assertEqual(deq_scale, quant_scale * mul_value)
632+
590633
def test_fuse_then_transpose_pass(self):
591634
# Create a graph with full -> transpose.
592635
builder = GraphBuilder()

0 commit comments

Comments
 (0)