Skip to content

Commit a81ea21

Browse files
committed
Qualcomm AI Engine Direct - Add 4-bit Embedding Quantization Option
Summary: - Introduce 4-bit embedding quantization for prefill, kv, and hybrid mode - Fix an assertion condition bug in the annotate_and_quant_scalar pass - Refactor passes in capture_program - Add topological sorting for passes in capture_program
1 parent e00eaea commit a81ea21

File tree

16 files changed

+413
-151
lines changed

16 files changed

+413
-151
lines changed

backends/qualcomm/_passes/__init__.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from .annotate_and_quant_scalar import AnnotateAndQuantScalar
2+
from .annotate_decomposed import AnnotateDecomposed
3+
from .annotate_quant_attrs import AnnotateQuantAttrs
4+
from .convert_bmm_to_matmul import ConvertBmmToMatmul
5+
from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
6+
from .convert_prelu import ConvertPReLU
7+
from .convert_to_linear import ConvertToLinear
8+
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
9+
from .fold_qdq import FoldQDQ
10+
from .i64_to_i32 import I64toI32
11+
from .layout_transform import LayoutTransform
12+
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
13+
from .recompose_rms_norm import RecomposeRmsNorm
14+
from .remove_redundancy import RemoveRedundancy
15+
from .replace_index_put_input import ReplaceIndexPutInput
16+
17+
18+
# Public names re-exported by this package.
# ``__all__`` must contain the names as strings: listing the class objects
# themselves makes ``from executorch.backends.qualcomm._passes import *``
# fail with a TypeError, because the import machinery looks each entry up by
# name on the module.
__all__ = [
    "AnnotateAndQuantScalar",
    "AnnotateDecomposed",
    "AnnotateQuantAttrs",
    "ConvertBmmToMatmul",
    "ConvertInterpolateWithUpsample2D",
    "ConvertPReLU",
    "ConvertToLinear",
    "ExpandBroadcastTensorShape",
    "FoldQDQ",
    "I64toI32",
    "LayoutTransform",
    "RecomposePixelUnshuffle",
    "RecomposeRmsNorm",
    "RemoveRedundancy",
    "ReplaceIndexPutInput",
]

backends/qualcomm/_passes/annotate_and_quant_scalar.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def _get_source_scalar_node(self, node: torch.fx.Node) -> torch.fx.Node:
5353
if node.op == "placeholder":
5454
if not (shape := node.meta["val"].size()):
5555
return node
56-
assert f"The output of node {node} is not a scalar, but a tensor with shape {shape}"
56+
assert (
57+
not shape
58+
), f"The output of node {node} is not a scalar, but a tensor with shape {shape}"
5759
return self._get_source_scalar_node(node.args[0])
5860

5961
def _update_scalar_node_attrs(self, node: torch.fx.Node, quant_attrs: Dict) -> Dict:

backends/qualcomm/_passes/i64_to_i32.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6+
from typing import FrozenSet
7+
68
import torch
79
from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant
810
from executorch.exir.dialects._ops import ops as exir_ops
@@ -15,9 +17,14 @@ class I64toI32(ExportPass):
1517
Cast unsupported int64 datatype into int32.
1618
"""
1719

18-
def __init__(self, edge_program: torch.export.ExportedProgram):
20+
def __init__(
21+
self,
22+
edge_program: torch.export.ExportedProgram,
23+
skip_node: FrozenSet[str] = frozenset(),
24+
):
1925
super(I64toI32, self).__init__()
2026
self.edge_program = edge_program
27+
self.skip_node = skip_node
2128
# pyre-ignore[4]
2229
self.copy_op = exir_ops.edge.aten._to_copy.default
2330

@@ -42,6 +49,8 @@ def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
4249

4350
def _cast_to_int32(self, graph_module: torch.fx.GraphModule):
4451
for n in graph_module.graph.nodes:
52+
if n.target in self.skip_node:
53+
continue
4554
if is_constant(n, self.edge_program):
4655
param = get_parameter(n, self.edge_program)
4756
if param.dtype == torch.int64:

backends/qualcomm/_passes/utils.py

+60
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,63 @@ def get_quant_attrs(
4343

4444
quant_attrs[QCOM_ENCODING] = quant_node.target
4545
return quant_attrs
46+
47+
48+
def get_passes_dependency_for_capture_program():
    """
    Build the dependency table for the passes used in capture_program.

    Every key is a pass class and its value is the list of pass classes that
    the key depends on. RemoveRedundancy has no entry of its own; it only
    appears as a dependency of other passes.

    Returns:
        dict: A mapping from each pass class to a list of the pass classes it
        depends on. A new dict (with new lists) is built on every call.
    """
    # Imported at call time rather than at module load.
    # NOTE(review): presumably this avoids a circular import between the
    # package __init__ and this utils module - confirm.
    from executorch.backends.qualcomm._passes import (
        AnnotateAndQuantScalar,
        AnnotateDecomposed,
        AnnotateQuantAttrs,
        ConvertBmmToMatmul,
        ConvertInterpolateWithUpsample2D,
        ConvertPReLU,
        ConvertToLinear,
        ExpandBroadcastTensorShape,
        FoldQDQ,
        I64toI32,
        LayoutTransform,
        RecomposePixelUnshuffle,
        RecomposeRmsNorm,
        RemoveRedundancy,
        ReplaceIndexPutInput,
    )

    # Entries are added one by one so the key order matches the declaration
    # order below; each entry gets its own list object.
    dependencies = {}

    # Graph rewrites that only need redundant nodes removed first, plus the
    # conversions chained on top of them.
    dependencies[RecomposePixelUnshuffle] = [RemoveRedundancy]
    dependencies[RecomposeRmsNorm] = [RemoveRedundancy]
    dependencies[ConvertToLinear] = [RecomposePixelUnshuffle]
    dependencies[ConvertPReLU] = [RemoveRedundancy]
    dependencies[ConvertBmmToMatmul] = [ConvertToLinear]
    dependencies[ConvertInterpolateWithUpsample2D] = [RemoveRedundancy]
    dependencies[I64toI32] = [RemoveRedundancy]

    # Quantization annotation depends on all recompose/convert rewrites.
    dependencies[AnnotateQuantAttrs] = [
        RecomposePixelUnshuffle,
        RecomposeRmsNorm,
        ConvertToLinear,
        ConvertPReLU,
        ConvertBmmToMatmul,
        ConvertInterpolateWithUpsample2D,
    ]
    dependencies[AnnotateAndQuantScalar] = [AnnotateQuantAttrs]
    dependencies[AnnotateDecomposed] = [RemoveRedundancy]

    # Q/DQ folding depends on every annotation pass.
    dependencies[FoldQDQ] = [
        AnnotateQuantAttrs,
        AnnotateAndQuantScalar,
        AnnotateDecomposed,
    ]

    dependencies[ExpandBroadcastTensorShape] = [RemoveRedundancy]
    dependencies[LayoutTransform] = [
        AnnotateQuantAttrs,
        AnnotateAndQuantScalar,
        ExpandBroadcastTensorShape,
    ]
    dependencies[ReplaceIndexPutInput] = [LayoutTransform]

    return dependencies

backends/qualcomm/scripts/build.sh

+3
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ if [ "$BUILD_AARCH64" = true ]; then
8787
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
8888
-DANDROID_ABI='arm64-v8a' \
8989
-DANDROID_NATIVE_API_LEVEL=23 \
90+
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
9091
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
9192
-B$BUILD_ROOT
9293

@@ -101,6 +102,7 @@ if [ "$BUILD_AARCH64" = true ]; then
101102
-DANDROID_ABI='arm64-v8a' \
102103
-DANDROID_NATIVE_API_LEVEL=23 \
103104
-DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
105+
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
104106
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
105107
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
106108
-B$EXAMPLE_ROOT
@@ -125,6 +127,7 @@ if [ "$BUILD_X86_64" = true ]; then
125127
-DEXECUTORCH_BUILD_QNN=ON \
126128
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
127129
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
130+
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
128131
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
129132
-DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
130133
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \

backends/qualcomm/utils/constants.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
QCOM_SCALE_OFFSET = "scale_offset"
2727
QCOM_ZERO_POINT = "zero_point"
2828
QCOM_ZERO_POINTS = "zero_points"
29-
QCOM_PASS_EXPAND_BROADCAST_SHAPE = "expand_broadcast_shape"
30-
QCOM_PASS_SKIP_ADVANCED_REQUANT = "skip_advanced_requant"
29+
QCOM_PASS_ACTIVATE_KEY = "activate"
30+
QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY = "args_kwargs_defaults"
3131

3232
# constants in backends/qualcomm/tests
3333
QCOM_ANNOTATION = "annotation"

0 commit comments

Comments
 (0)