@@ -863,6 +863,76 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
863
863
return result
864
864
865
865
866
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoQuantPass(ExportPass):
    """
    Looks for the pattern where aten.mul.Tensor (by a scalar materialized via
    aten.full) is followed by a quant node. If found, updates the quant scale
    to reflect the multiplication and removes the mul node.

    Since quantize_per_tensor computes round(input / scale) + zero_point,
    multiplying the input by a constant c is equivalent to dividing the
    quantization scale by c: quant(x * c, scale) == quant(x, scale / c).
    """

    def attempt_fusion(
        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
    ) -> None:
        """Try to fuse a single mul node into its sole quantize user.

        No-op unless:
          - exactly one of mul's args is an aten.full.default node (the scalar),
          - mul has exactly one user, and that user is a quantize_per_tensor,
          - the fill value is a nonzero int/float (zero would make the folded
            scale a division by zero).
        """
        full_nodes = [
            arg
            for arg in mul_node.args
            if isinstance(arg, torch.fx.Node)
            and arg.target == exir_ops.edge.aten.full.default
        ]

        if len(full_nodes) != 1 or len(mul_node.users) != 1:
            return

        full_node = full_nodes[0]

        # The tensor operand is whichever mul arg is NOT the full node; it
        # becomes the new input of the quant node. (Using mul_node.args[0]
        # unconditionally would wire in the scalar constant when the graph
        # has full * x instead of x * full.)
        other_inputs = [
            arg
            for arg in mul_node.args
            if isinstance(arg, torch.fx.Node) and arg is not full_node
        ]
        if len(other_inputs) != 1:
            return
        tensor_input = other_inputs[0]

        mul_user = list(mul_node.users.keys())[0]

        if mul_user.target not in {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
        }:
            return

        quant_node = mul_user

        # Calculate the new scale value. quantize computes
        # round(input / scale) + zp, so folding mul-by-c into the quant means
        # DIVIDING the scale by c (multiplying would change the numerics).
        prev_scale = quant_node.args[1]
        assert isinstance(prev_scale, (int, float))
        mul_scalar = full_node.args[1]
        assert isinstance(mul_scalar, (int, float))
        if mul_scalar == 0:
            # Multiply-by-zero cannot be folded into the scale.
            return
        new_scale = float(prev_scale) / float(mul_scalar)

        logging.debug(
            f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
        )

        # Replace the input first, so the quant node consumes the original
        # (un-multiplied) tensor directly.
        quant_node.replace_input_with(
            cast(torch.fx.Node, quant_node.args[0]),
            tensor_input,
        )

        # Now update the scale in the args.
        new_quant_args = list(quant_node.args)
        new_quant_args[1] = new_scale
        quant_node.args = tuple(new_quant_args)

        # Clean up the mul_node: drop its args/users so erase_node succeeds.
        mul_node.args = ()
        mul_node.users = {}

        graph_module.graph.erase_node(mul_node)
        graph_module.graph.erase_node(full_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Apply the fusion to every aten.mul.Tensor node, then DCE."""
        for node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.mul.Tensor
        ):
            self.attempt_fusion(graph_module, node)
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)
866
936
@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
867
937
class FuseMulTensorIntoDequantPass (ExportPass ):
868
938
"""
0 commit comments