Skip to content

Commit 19fd52a

Browse files
author
morelos
committed
Update on "[ET-VK] double, short, and uint16 dtype runtime support"
Creating support for double, short, and uint16 for quantization ops. Registering the short keyword since there's already support. Also changing the CPU implementation to support half. Differential Revision: [D75959063](https://our.internmc.facebook.com/intern/diff/D75959063/) [ghstack-poisoned]
2 parents 7f85daf + ccce41b commit 19fd52a

File tree

125 files changed

+3422
-259
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

125 files changed

+3422
-259
lines changed

backends/cadence/aot/fuse_ops.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,76 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
863863
return result
864864

865865

866+
@register_cadence_pass(CadencePassAttribute(opt_level=1))
867+
class FuseMulTensorIntoQuantPass(ExportPass):
868+
"""
869+
Looks for the pattern where aten.mul.Tensor is followed by quant node.
870+
If found, updates the quant scale to reflect the multiplication and
871+
removes the mul node.
872+
"""
873+
874+
def attempt_fusion(
875+
self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node
876+
) -> None:
877+
full_nodes = [
878+
arg
879+
for arg in mul_node.args
880+
if isinstance(arg, torch.fx.Node)
881+
and arg.target == exir_ops.edge.aten.full.default
882+
]
883+
884+
if len(full_nodes) != 1 or len(mul_node.users) != 1:
885+
return
886+
887+
full_node = full_nodes[0]
888+
mul_user = list(mul_node.users.keys())[0]
889+
890+
if mul_user.target not in {
891+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
892+
exir_ops.edge.cadence.quantize_per_tensor.default,
893+
}:
894+
return
895+
896+
quant_node = mul_user
897+
898+
# Calculate the new scale value.
899+
prev_scale = quant_node.args[1]
900+
assert isinstance(prev_scale, (int, float))
901+
mul_scalar = full_node.args[1]
902+
assert isinstance(mul_scalar, (int, float))
903+
new_scale = float(prev_scale) * float(mul_scalar)
904+
905+
logging.debug(
906+
f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}"
907+
)
908+
909+
# Replace the input first
910+
quant_node.replace_input_with(
911+
cast(torch.fx.Node, quant_node.args[0]),
912+
cast(torch.fx.Node, mul_node.args[0]),
913+
)
914+
915+
# Now update the scale in the args
916+
new_quant_args = list(quant_node.args)
917+
new_quant_args[1] = new_scale
918+
quant_node.args = tuple(new_quant_args)
919+
920+
# Clean up the mul_node
921+
mul_node.args = ()
922+
mul_node.users = {}
923+
924+
graph_module.graph.erase_node(mul_node)
925+
graph_module.graph.erase_node(full_node)
926+
927+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
928+
for node in graph_module.graph.find_nodes(
929+
op="call_function", target=exir_ops.edge.aten.mul.Tensor
930+
):
931+
self.attempt_fusion(graph_module, node)
932+
graph_module.graph.eliminate_dead_code()
933+
return super().call(graph_module)
934+
935+
866936
@register_cadence_pass(CadencePassAttribute(opt_level=1))
867937
class FuseMulTensorIntoDequantPass(ExportPass):
868938
"""

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
FuseMMWithAdd,
2121
FuseMulScalarIntoDequantPass,
2222
FuseMulTensorIntoDequantPass,
23+
FuseMulTensorIntoQuantPass,
2324
FuseQuantDequantToRequantizePass,
2425
FuseTransposeOrPermuteOpPairsPass,
2526
)
@@ -587,6 +588,48 @@ def test_fuse_mul_scalar_into_dequant(self):
587588
deq_scale = node.args[1]
588589
self.assertEqual(deq_scale, dequant_scale * mul_value)
589590

591+
def test_fuse_mul_into_quant(self):
592+
quant_scale = 1.5
593+
mul_value = 10
594+
595+
builder = GraphBuilder()
596+
x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
597+
full = builder.call_operator(
598+
op=exir_ops.edge.aten.full.default,
599+
args=([1], mul_value),
600+
)
601+
mul = builder.call_operator(
602+
op=exir_ops.edge.aten.mul.Tensor,
603+
args=(x, full),
604+
)
605+
quant = builder.call_operator(
606+
op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
607+
args=(mul, quant_scale, 0, 0, 255, torch.uint8),
608+
)
609+
builder.output(quant)
610+
graph_module = FuseMulTensorIntoQuantPass()(
611+
builder.get_graph_module()
612+
).graph_module
613+
614+
# verify that the mul and full ops were removed
615+
self.check_op_counts(
616+
graph_module,
617+
expected_op_counts={
618+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
619+
exir_ops.edge.aten.full.default: 0,
620+
exir_ops.edge.aten.mul.Tensor: 0,
621+
},
622+
)
623+
624+
# verify that the quant scale value was updated correctly
625+
for node in graph_module.graph.nodes:
626+
if (
627+
node.target
628+
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
629+
):
630+
deq_scale = node.args[1]
631+
self.assertEqual(deq_scale, quant_scale * mul_value)
632+
590633
def test_fuse_then_transpose_pass(self):
591634
# Create a graph with full -> transpose.
592635
builder = GraphBuilder()

backends/qualcomm/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ add_library(qnn_implementation STATIC)
130130
add_library(qnn_logger STATIC)
131131
add_library(qnn_manager STATIC)
132132
add_library(qnn_mem_manager STATIC)
133+
add_library(qnn_op_package_manager STATIC)
133134
add_library(qnn_profiler STATIC)
134135
add_library(qnn_schema INTERFACE ${_qnn_schema__outputs})
135136
add_library(qnn_sys_function_interface INTERFACE)
@@ -152,7 +153,7 @@ target_link_libraries(
152153
target_link_libraries(qnn_executorch_logging PRIVATE qnn_schema)
153154
target_link_libraries(qnn_profiler PRIVATE qnn_executorch_logging)
154155
target_link_libraries(qnn_logger PRIVATE qnn_implementation ${android_log})
155-
target_link_libraries(qnn_backend PRIVATE qnn_implementation qnn_logger)
156+
target_link_libraries(qnn_backend PRIVATE qnn_implementation qnn_logger qnn_op_package_manager)
156157
target_link_libraries(qnn_custom_protocol PRIVATE qnn_logger)
157158
target_link_libraries(
158159
qnn_device PRIVATE qnn_executorch_logging qnn_implementation qnn_logger

backends/qualcomm/builders/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ import torch
176176
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
177177
# op builder will inherit NodeVisitor and have its own implementation
178178
# register_node_visitor for bookkeeping the dictionary of target name vs. callback
179-
from .node_visitor import NodeVisitor, register_node_visitor
179+
from .node_visitor import NodeVisitor
180+
from .node_visitor_manager import register_node_visitor
180181
# the definitions required to build operator in QNN
181182
from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
182183
# utility to get parameter value when creating tensor in QNN

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,9 @@
6363
torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64,
6464
torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8,
6565
torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16,
66+
torch.uint32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
6667
float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
68+
int: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
6769
}
6870

6971
PER_CHANNEL_ENCODING = {
@@ -470,51 +472,3 @@ def define_node(
470472
) -> PyQnnWrapper.PyQnnOpWrapper:
471473
"""Convert torch.fx.Node to OpWrapper"""
472474
raise NotImplementedError("NodeVisitor must be extended!")
473-
474-
475-
# This will hold mapping of all node names to the visitor class
476-
_node_visitor_dict = {}
477-
478-
479-
def register_node_visitor(visitor):
480-
"""Register node visitor into _node_visitor_dict"""
481-
assert (
482-
isinstance(visitor, type)
483-
and issubclass(visitor, NodeVisitor)
484-
and hasattr(visitor, "target")
485-
), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
486-
for target in visitor.target:
487-
_node_visitor_dict[target] = visitor
488-
489-
490-
def generate_node_to_external_map(
491-
edge_program: torch.export.ExportedProgram,
492-
) -> Dict[torch.fx.Node, int]:
493-
node_to_external_map = {}
494-
for node in edge_program.graph_module.graph.nodes:
495-
# The order in which we visit the placeholder node is same as the *args
496-
# order for the forward(*args) signature for this gm. Using the order of
497-
# the nodes as external_id to extract the right arg from *args at runtime
498-
if is_graph_input(node, edge_program):
499-
node_to_external_map[node] = len(node_to_external_map)
500-
for node in edge_program.graph_module.graph.nodes:
501-
if is_graph_output(node):
502-
node_to_external_map[node] = len(node_to_external_map)
503-
return node_to_external_map
504-
505-
506-
def get_node_visitors(
507-
edge_program: torch.export.ExportedProgram,
508-
enable_tensor_dump=False,
509-
) -> Dict[str, NodeVisitor]:
510-
"""Create a new class instance at runtime, and put them in a dict"""
511-
node_to_external_map = generate_node_to_external_map(edge_program)
512-
node_visitors = {}
513-
for target, visitor in _node_visitor_dict.items():
514-
assert callable(
515-
visitor
516-
), f"Expeting a callable class, but got {visitor} of type {type(visitor)}"
517-
node_visitors[target] = visitor(
518-
node_to_external_map, edge_program, enable_tensor_dump
519-
)
520-
return node_visitors
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Dict, List
8+
9+
import torch
10+
from executorch.backends.qualcomm.serialization.qc_schema import (
11+
QnnExecuTorchOpPackageInfo,
12+
)
13+
14+
from .node_visitor import NodeVisitor
15+
from .op_custom_op import CustomOp
16+
from .utils import is_graph_input, is_graph_output
17+
18+
19+
# This will hold mapping of all node names to the visitor class
20+
_node_visitor_dict = {}
21+
22+
23+
def register_node_visitor(visitor):
24+
"""Register node visitor into _node_visitor_dict"""
25+
assert (
26+
isinstance(visitor, type)
27+
and issubclass(visitor, NodeVisitor)
28+
and hasattr(visitor, "target")
29+
), f"Informed NodeVisitor subclass, can't register!, got: {visitor}"
30+
for target in visitor.target:
31+
_node_visitor_dict[target] = visitor
32+
33+
34+
def generate_node_to_external_map(
35+
edge_program: torch.export.ExportedProgram,
36+
) -> Dict[torch.fx.Node, int]:
37+
node_to_external_map = {}
38+
for node in edge_program.graph_module.graph.nodes:
39+
# The order in which we visit the placeholder node is same as the *args
40+
# order for the forward(*args) signature for this gm. Using the order of
41+
# the nodes as external_id to extract the right arg from *args at runtime
42+
if is_graph_input(node, edge_program):
43+
node_to_external_map[node] = len(node_to_external_map)
44+
for node in edge_program.graph_module.graph.nodes:
45+
if is_graph_output(node):
46+
node_to_external_map[node] = len(node_to_external_map)
47+
return node_to_external_map
48+
49+
50+
def get_node_visitors(
51+
edge_program: torch.export.ExportedProgram,
52+
enable_tensor_dump=False,
53+
op_package_infos: List[QnnExecuTorchOpPackageInfo] = None,
54+
) -> Dict[str, NodeVisitor]:
55+
"""Create a new class instance at runtime, and put them in a dict"""
56+
node_to_external_map = generate_node_to_external_map(edge_program)
57+
node_visitors = {}
58+
for target, visitor in _node_visitor_dict.items():
59+
assert callable(
60+
visitor
61+
), f"Expecting a callable class, but got {visitor} of type {type(visitor)}"
62+
node_visitors[target] = visitor(
63+
node_to_external_map, edge_program, enable_tensor_dump
64+
)
65+
if op_package_infos:
66+
custom_ops = []
67+
for op_package_info in op_package_infos:
68+
if op_package_info.custom_op_name not in custom_ops:
69+
custom_op_builder = CustomOp(
70+
op_package_info,
71+
node_to_external_map,
72+
edge_program,
73+
enable_tensor_dump,
74+
)
75+
node_visitors[op_package_info.custom_op_name] = custom_op_builder
76+
custom_ops.append(op_package_info.custom_op_name)
77+
return node_visitors

backends/qualcomm/builders/op_abs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAbs, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_adaptive_avg_pool2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
import torch
1313

14-
from .node_visitor import NodeVisitor, register_node_visitor
14+
from .node_visitor import NodeVisitor
15+
from .node_visitor_manager import register_node_visitor
1516
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
1617

1718

backends/qualcomm/builders/op_add.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAdd, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_amax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import torch
1313
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
1414

15-
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .node_visitor import NodeVisitor
16+
from .node_visitor_manager import register_node_visitor
1617
from .qnn_constants import OpReduceMax, QNN_OP_PACKAGE_NAME_QTI_AISW
1718

1819

backends/qualcomm/builders/op_and.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314
from .qnn_constants import OpElementWiseAnd, QNN_OP_PACKAGE_NAME_QTI_AISW
1415

1516

backends/qualcomm/builders/op_arange.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111

12-
from .node_visitor import NodeVisitor, register_node_visitor
12+
from .node_visitor import NodeVisitor
13+
from .node_visitor_manager import register_node_visitor
1314

1415

1516
@register_node_visitor

backends/qualcomm/builders/op_argmin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import torch
1111
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
1212

13-
from .node_visitor import NodeVisitor, register_node_visitor
13+
from .node_visitor import NodeVisitor
14+
from .node_visitor_manager import register_node_visitor
1415
from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW
1516

1617

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import torch
1313
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
1414

15-
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .node_visitor import NodeVisitor
16+
from .node_visitor_manager import register_node_visitor
1617
from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
1718

1819

backends/qualcomm/builders/op_batch_norm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
)
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020

21-
from .node_visitor import NodeVisitor, register_node_visitor
21+
from .node_visitor import NodeVisitor
22+
from .node_visitor_manager import register_node_visitor
2223
from .qnn_constants import OpBatchnorm, QNN_OP_PACKAGE_NAME_QTI_AISW
2324
from .utils import get_parameter
2425

0 commit comments

Comments
 (0)