Skip to content

Commit ba8f6de

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] dequantize_per_token.default test setup"
Creating dequantize_per_token testing framework along with a reference implementation for testing Differential Revision: [D76267037](https://our.internmc.facebook.com/intern/diff/D76267037/) [ghstack-poisoned]
2 parents 334872a + 4ec67b0 commit ba8f6de

File tree

125 files changed

+3422
-259
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

125 files changed

+3422
-259
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,76 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
863863
return result
864864

865865

866+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoQuantPass(ExportPass):
    """
    Looks for the pattern where aten.mul.Tensor is followed by quant node.
    If found, updates the quant scale to reflect the multiplication and
    removes the mul node.

    Only a multiplication by a constant produced by aten.full is fused.
    With y = x * c, the quant node computes
        q = zp + y / scale = zp + x * c / scale = zp + x / (scale / c)
    so the fused quant node must use new_scale = scale / c.
    """

    def attempt_fusion(
        self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
    ) -> None:
        """
        Fold ``mul_node`` into its single quantize user when that is possible.

        The graph is left untouched unless all of the following hold:
        exactly one operand of the mul is an aten.full node, the mul has a
        single user which is a supported quantize_per_tensor op consuming the
        mul as its input, both the quant scale and the fill value are plain
        Python numbers, the fill value is non-zero, and the mul does not
        change the shape of its non-constant operand.

        Args:
            graph_module: Graph module that owns ``mul_node``.
            mul_node: Candidate aten.mul.Tensor node.
        """
        full_nodes = [
            arg
            for arg in mul_node.args
            if isinstance(arg, torch.fx.Node)
            and arg.target == exir_ops.edge.aten.full.default
        ]

        if len(full_nodes) != 1 or len(mul_node.users) != 1:
            return

        full_node = full_nodes[0]

        # The constant may be either operand of the mul; the tensor that has
        # to reach the quant node is whichever operand is *not* the constant.
        input_nodes = [arg for arg in mul_node.args if arg is not full_node]
        if len(input_nodes) != 1 or not isinstance(input_nodes[0], torch.fx.Node):
            return
        input_node = input_nodes[0]

        quant_node = next(iter(mul_node.users))
        if quant_node.target not in {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            exir_ops.edge.cadence.quantize_per_tensor.default,
        }:
            return

        # The mul must be the tensor being quantized.
        if not quant_node.args or quant_node.args[0] is not mul_node:
            return

        # Skip (instead of asserting) when the values are not plain numbers:
        # this is an optional optimization, and asserts vanish under -O.
        prev_scale = quant_node.args[1]
        mul_scalar = full_node.args[1]
        if not isinstance(prev_scale, (int, float)) or not isinstance(
            mul_scalar, (int, float)
        ):
            return

        # A zero multiplier cannot be expressed as a finite quant scale.
        if mul_scalar == 0:
            return

        # Removing the mul must not change the tensor shape seen by the quant
        # node. The check is only applied when shape metadata is available.
        input_shape = getattr(input_node.meta.get("val"), "shape", None)
        mul_shape = getattr(mul_node.meta.get("val"), "shape", None)
        if (
            input_shape is not None
            and mul_shape is not None
            and input_shape != mul_shape
        ):
            return

        # Multiplying the input by c is equivalent to dividing the scale by c.
        new_scale = float(prev_scale) / float(mul_scalar)

        logging.debug(
            "Fused %s and %s into %s. Updated scale from %s to %s",
            mul_node,
            full_node,
            quant_node,
            prev_scale,
            new_scale,
        )

        # Replace the input first
        quant_node.replace_input_with(mul_node, input_node)

        # Now update the scale in the args
        new_quant_args = list(quant_node.args)
        new_quant_args[1] = new_scale
        quant_node.args = tuple(new_quant_args)

        # The mul has no users left, so it can be erased directly. The full
        # node may still feed other nodes and is only erased when unused.
        graph_module.graph.erase_node(mul_node)
        if not full_node.users:
            graph_module.graph.erase_node(full_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Try the fusion on every aten.mul.Tensor node, then clean up."""
        for node in graph_module.graph.find_nodes(
            op="call_function", target=exir_ops.edge.aten.mul.Tensor
        ):
            self.attempt_fusion(graph_module, node)
        graph_module.graph.eliminate_dead_code()
        return super().call(graph_module)
934+
935+
866936
@register_cadence_pass(CadencePassAttribute(opt_level=1))
867937
class FuseMulTensorIntoDequantPass(ExportPass):
868938
"""

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
FuseMMWithAdd,
2121
FuseMulScalarIntoDequantPass,
2222
FuseMulTensorIntoDequantPass,
23+
FuseMulTensorIntoQuantPass,
2324
FuseQuantDequantToRequantizePass,
2425
FuseTransposeOrPermuteOpPairsPass,
2526
)
@@ -587,6 +588,48 @@ def test_fuse_mul_scalar_into_dequant(self):
587588
deq_scale = node.args[1]
588589
self.assertEqual(deq_scale, dequant_scale * mul_value)
589590

591+
def test_fuse_mul_into_quant(self):
    """
    A mul by a constant feeding quantize_per_tensor must be folded into
    the quant scale.

    Quantizing x * c with scale s computes zp + x * c / s, which equals
    quantizing x with scale s / c, so the fused scale is s / c.
    """
    quant_scale = 1.5
    mul_value = 10

    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
    full = builder.call_operator(
        op=exir_ops.edge.aten.full.default,
        args=([1], mul_value),
    )
    mul = builder.call_operator(
        op=exir_ops.edge.aten.mul.Tensor,
        args=(x, full),
    )
    quant = builder.call_operator(
        op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        args=(mul, quant_scale, 0, 0, 255, torch.uint8),
    )
    builder.output(quant)
    graph_module = FuseMulTensorIntoQuantPass()(
        builder.get_graph_module()
    ).graph_module

    # verify that the mul and full ops were removed
    self.check_op_counts(
        graph_module,
        expected_op_counts={
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            exir_ops.edge.aten.full.default: 0,
            exir_ops.edge.aten.mul.Tensor: 0,
        },
    )

    # verify that the quant scale value was updated correctly
    quant_nodes = [
        node
        for node in graph_module.graph.nodes
        if node.target
        == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
    ]
    # Guard against the scale check being skipped silently.
    self.assertEqual(len(quant_nodes), 1)
    fused_scale = quant_nodes[0].args[1]
    self.assertEqual(fused_scale, quant_scale / mul_value)
632+
590633
def test_fuse_then_transpose_pass(self):
591634
# Create a graph with full -> transpose.
592635
builder = GraphBuilder()

backends/qualcomm/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ add_library(qnn_implementation STATIC)
130130
add_library(qnn_logger STATIC)
131131
add_library(qnn_manager STATIC)
132132
add_library(qnn_mem_manager STATIC)
133+
add_library(qnn_op_package_manager STATIC)
133134
add_library(qnn_profiler STATIC)
134135
add_library(qnn_schema INTERFACE ${_qnn_schema__outputs})
135136
add_library(qnn_sys_function_interface INTERFACE)
@@ -152,7 +153,7 @@ target_link_libraries(
152153
target_link_libraries(qnn_executorch_logging PRIVATE qnn_schema)
153154
target_link_libraries(qnn_profiler PRIVATE qnn_executorch_logging)
154155
target_link_libraries(qnn_logger PRIVATE qnn_implementation ${android_log})
155-
target_link_libraries(qnn_backend PRIVATE qnn_implementation qnn_logger)
156+
target_link_libraries(qnn_backend PRIVATE qnn_implementation qnn_logger qnn_op_package_manager)
156157
target_link_libraries(qnn_custom_protocol PRIVATE qnn_logger)
157158
target_link_libraries(
158159
qnn_device PRIVATE qnn_executorch_logging qnn_implementation qnn_logger

backends/qualcomm/builders/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ import torch
176176
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
177177
# op builder will inherit NodeVisitor and have its own implementation
178178
# register_node_visitor for book-keeping the dictionary of target name v.s. callback
179-
from .node_visitor import NodeVisitor, register_node_visitor
179+
from .node_visitor import NodeVisitor
180+
from .node_visitor_manager import register_node_visitor
180181
# the definitions required to build operator in QNN
181182
from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
182183
# utility to get parameter value when creating tensor in QNN

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@
6363
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
6464
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
6565
torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
66+
torch.uint32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
6667
float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
68+
int: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
6769
}
6870

6971
PER_CHANNEL_ENCODING = {
@@ -470,51 +472,3 @@ def define_node(
470472
) -> PyQnnWrapper.PyQnnOpWrapper:
471473
"""Convert torch.fx.Node to OpWrapper"""
472474
raise NotImplementedError("NodeVisitor must be extended!")
473-
474-
475-
# This will hold mapping of all node names to the visitor class
476-
_node_visitor_dict = {}
477-
478-
479-
def register_node_visitor(visitor):
480-
"""Register node visitor into _node_visitor_dict"""
481-
assert (
482-
isinstance(visitor, type)
483-
and issubclass(visitor, NodeVisitor)
484-
and hasattr(visitor, "target")
485-
), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
486-
for target in visitor.target:
487-
_node_visitor_dict[target] = visitor
488-
489-
490-
def generate_node_to_external_map(
491-
edge_program: torch.export.ExportedProgram,
492-
) -> Dict[torch.fx.Node, int]:
493-
node_to_external_map = {}
494-
for node in edge_program.graph_module.graph.nodes:
495-
# The order in which we visit the placeholder node is same as the *args
496-
# order for the forward(*args) signature for this gm. Using the order of
497-
# the nodes as external_id to extract the right arg from *args at runtime
498-
if is_graph_input(node, edge_program):
499-
node_to_external_map[node] = len(node_to_external_map)
500-
for node in edge_program.graph_module.graph.nodes:
501-
if is_graph_output(node):
502-
node_to_external_map[node] = len(node_to_external_map)
503-
return node_to_external_map
504-
505-
506-
def get_node_visitors(
507-
edge_program: torch.export.ExportedProgram,
508-
enable_tensor_dump=False,
509-
) -> Dict[str, NodeVisitor]:
510-
"""Create a new class instance at runtime, and put them in a dict"""
511-
node_to_external_map = generate_node_to_external_map(edge_program)
512-
node_visitors = {}
513-
for target, visitor in _node_visitor_dict.items():
514-
assert callable(
515-
visitor
516-
), f"Expeting a callable class, but got {visitor} of type {type(visitor)}"
517-
node_visitors[target] = visitor(
518-
node_to_external_map, edge_program, enable_tensor_dump
519-
)
520-
return node_visitors
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict, List
8+
9+
import torch
10+
from executorch.backends.qualcomm.serialization.qc_schema import (
11+
QnnExecuTorchOpPackageInfo,
12+
)
13+
14+
from .node_visitor import NodeVisitor
15+
from .op_custom_op import CustomOp
16+
from .utils import is_graph_input, is_graph_output
17+
18+
19+
# This will hold mapping of all node names to the visitor class
20+
_node_visitor_dict = {}
21+
22+
23+
def register_node_visitor(visitor):
    """Register node visitor into _node_visitor_dict

    Intended to be used as a class decorator on NodeVisitor subclasses.
    Every entry of ``visitor.target`` is mapped to the visitor class.

    Args:
        visitor: NodeVisitor subclass exposing a ``target`` attribute that
            iterates over the target names the visitor handles.

    Returns:
        The visitor class itself, so that the decorated name stays bound to
        the class instead of being replaced by None.

    Raises:
        AssertionError: if ``visitor`` is not a NodeVisitor subclass with a
            ``target`` attribute.
    """
    assert (
        isinstance(visitor, type)
        and issubclass(visitor, NodeVisitor)
        and hasattr(visitor, "target")
    ), f"Ill-formed NodeVisitor subclass, can't register!, got: {visitor}"
    for target in visitor.target:
        _node_visitor_dict[target] = visitor
    return visitor
32+
33+
34+
def generate_node_to_external_map(
    edge_program: torch.export.ExportedProgram,
) -> Dict[torch.fx.Node, int]:
    """Assign consecutive external ids to the graph inputs, then the outputs.

    Ids are handed out in node order: every node accepted by
    ``is_graph_input`` is numbered first, and numbering continues with every
    node accepted by ``is_graph_output``.
    """
    external_ids = {}

    # Placeholders are visited in the same order as the *args of the
    # forward(*args) signature of this graph module, so the position of a
    # node doubles as the external_id used to pick the matching argument at
    # runtime.
    for candidate in edge_program.graph_module.graph.nodes:
        if not is_graph_input(candidate, edge_program):
            continue
        external_ids[candidate] = len(external_ids)

    # Outputs continue the numbering right after the inputs.
    for candidate in edge_program.graph_module.graph.nodes:
        if not is_graph_output(candidate):
            continue
        external_ids[candidate] = len(external_ids)

    return external_ids
48+
49+
50+
def get_node_visitors(
    edge_program: torch.export.ExportedProgram,
    enable_tensor_dump=False,
    op_package_infos: List[QnnExecuTorchOpPackageInfo] = None,
) -> Dict[str, NodeVisitor]:
    """Instantiate every registered visitor and return them keyed by target.

    When ``op_package_infos`` is given, one CustomOp builder is additionally
    created per distinct ``custom_op_name`` (the first entry with a given
    name wins) and stored under that name.
    """
    external_ids = generate_node_to_external_map(edge_program)

    visitors = {}
    for target, visitor_cls in _node_visitor_dict.items():
        assert callable(
            visitor_cls
        ), f"Expecting a callable class, but got {visitor_cls} of type {type(visitor_cls)}"
        visitors[target] = visitor_cls(external_ids, edge_program, enable_tensor_dump)

    # Names already handled, so that repeated package infos for the same
    # custom op do not build a second CustomOp.
    handled_names = set()
    for package_info in op_package_infos or ():
        op_name = package_info.custom_op_name
        if op_name in handled_names:
            continue
        visitors[op_name] = CustomOp(
            package_info,
            external_ids,
            edge_program,
            enable_tensor_dump,
        )
        handled_names.add(op_name)

    return visitors

backends/qualcomm/builders/op_abs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAbs, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_adaptive_avg_pool2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
import torch
1313

14-
from .node_visitor import NodeVisitor, register_node_visitor
14+
from .node_visitor import NodeVisitor
15+
from .node_visitor_manager import register_node_visitor
1516
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
1617

1718

backends/qualcomm/builders/op_add.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAdd, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_amax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import torch
1313
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
1414

15-
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .node_visitor import NodeVisitor
16+
from .node_visitor_manager import register_node_visitor
1617
from .qnn_constants import OpReduceMax, QNN_OP_PACKAGE_NAME_QTI_AISW
1718

1819

backends/qualcomm/builders/op_and.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAnd, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_arange.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314

1415

1516
@register_node_visitor

backends/qualcomm/builders/op_argmin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import torch
1111
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
1212

13-
from .node_visitor import NodeVisitor, register_node_visitor
13+
from .node_visitor import NodeVisitor
14+
from .node_visitor_manager import register_node_visitor
1415
from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW
1516

1617

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import torch
1313
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1414

15-
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .node_visitor import NodeVisitor
16+
from .node_visitor_manager import register_node_visitor
1617
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
1718

1819

backends/qualcomm/builders/op_batch_norm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
)
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020

21-
from .node_visitor import NodeVisitor, register_node_visitor
21+
from .node_visitor import NodeVisitor
22+
from .node_visitor_manager import register_node_visitor
2223
from .qnn_constants import OpBatchnorm, QNN_OP_PACKAGE_NAME_QTI_AISW
2324
from .utils import get_parameter
2425

0 commit comments

Comments
 (0)