Qualcomm AI Engine Direct - Add 4-bit Embedding Quantization Option #7691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged (2 commits, Feb 3, 2025)
34 changes: 34 additions & 0 deletions backends/qualcomm/_passes/__init__.py
@@ -0,0 +1,34 @@
from .annotate_and_quant_scalar import AnnotateAndQuantScalar
from .annotate_decomposed import AnnotateDecomposed
from .annotate_quant_attrs import AnnotateQuantAttrs
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
from .convert_prelu import ConvertPReLU
from .convert_to_linear import ConvertToLinear
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
from .fold_qdq import FoldQDQ
from .i64_to_i32 import I64toI32
from .layout_transform import LayoutTransform
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
from .recompose_rms_norm import RecomposeRmsNorm
from .remove_redundancy import RemoveRedundancy
from .replace_index_put_input import ReplaceIndexPutInput


__all__ = [
AnnotateAndQuantScalar,
AnnotateDecomposed,
AnnotateQuantAttrs,
ConvertBmmToMatmul,
ConvertInterpolateWithUpsample2D,
ConvertPReLU,
ConvertToLinear,
ExpandBroadcastTensorShape,
FoldQDQ,
I64toI32,
LayoutTransform,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
RemoveRedundancy,
ReplaceIndexPutInput,
]
4 changes: 3 additions & 1 deletion backends/qualcomm/_passes/annotate_and_quant_scalar.py
@@ -53,7 +53,9 @@ def _get_source_scalar_node(self, node: torch.fx.Node) -> torch.fx.Node:
if node.op == "placeholder":
if not (shape := node.meta["val"].size()):
return node
assert f"The output of node {node} is not a scalar, but a tensor with shape {shape}"
assert (
not shape
), f"The output of node {node} is not a scalar, but a tensor with shape {shape}"
return self._get_source_scalar_node(node.args[0])

def _update_scalar_node_attrs(self, node: torch.fx.Node, quant_attrs: Dict) -> Dict:
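Note on the change above: the original single-expression assert only evaluated the f-string, and a non-empty string is always truthy, so the check could never fire. The two-argument form asserts the actual condition (not shape) and keeps the f-string as the failure message. A minimal standalone illustration of the difference (the shape value here is made up for the example, not taken from the pass):

import torch

shape = torch.Size([2, 3])  # non-scalar shape, so the check should fail

# Old form: a non-empty f-string is always truthy, so this assert never raises.
assert f"The output is not a scalar, but a tensor with shape {shape}"

# Fixed form: `not shape` is the condition; the f-string is only the message.
try:
    assert not shape, f"The output is not a scalar, but a tensor with shape {shape}"
except AssertionError as error:
    print(f"AssertionError raised as intended: {error}")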
11 changes: 10 additions & 1 deletion backends/qualcomm/_passes/i64_to_i32.py
@@ -3,6 +3,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import FrozenSet

import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant
from executorch.exir.dialects._ops import ops as exir_ops
@@ -15,9 +17,14 @@ class I64toI32(ExportPass):
Cast unsupported int64 datatype into int32.
"""

def __init__(self, edge_program: torch.export.ExportedProgram):
def __init__(
self,
edge_program: torch.export.ExportedProgram,
skip_node: FrozenSet[str] = frozenset(),
):
super(I64toI32, self).__init__()
self.edge_program = edge_program
self.skip_node = skip_node
# pyre-ignore[4]
self.copy_op = exir_ops.edge.aten._to_copy.default

@@ -42,6 +49,8 @@ def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:

def _cast_to_int32(self, graph_module: torch.fx.GraphModule):
for n in graph_module.graph.nodes:
if n.target in self.skip_node:
continue
if is_constant(n, self.edge_program):
param = get_parameter(n, self.edge_program)
if param.dtype == torch.int64:
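The new skip_node argument lets callers exclude named graph nodes from the int64-to-int32 cast, for example to keep an index tensor that a quantized embedding op expects as int64. A minimal sketch of how the extended pass might be invoked, assuming the usual ExportPass calling convention; the exported program and the "tokens" node name are placeholders for illustration, not from this PR:

import torch
from executorch.backends.qualcomm._passes import I64toI32

def cast_int64_except_tokens(edge_program: torch.export.ExportedProgram):
    # Leave the node whose target is "tokens" untouched (e.g., indices feeding a
    # quantized embedding op); other unsupported int64 values are cast to int32.
    cast_pass = I64toI32(edge_program, skip_node=frozenset({"tokens"}))
    return cast_pass(edge_program.graph_module)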
60 changes: 60 additions & 0 deletions backends/qualcomm/_passes/utils.py
@@ -43,3 +43,63 @@ def get_quant_attrs(

quant_attrs[QCOM_ENCODING] = quant_node.target
return quant_attrs


def get_passes_dependency_for_capture_program():
"""
This function records the dependencies for passes used in the capture_program.

It returns a dictionary where the keys are pass classes and the values are lists of
dependencies required by each pass. This helps in managing and organizing the sequence
of passes needed for the capture_program to function correctly.

Returns:
dict: A dictionary mapping each pass to its corresponding list of dependencies.
"""
from executorch.backends.qualcomm._passes import (
AnnotateAndQuantScalar,
AnnotateDecomposed,
AnnotateQuantAttrs,
ConvertBmmToMatmul,
ConvertInterpolateWithUpsample2D,
ConvertPReLU,
ConvertToLinear,
ExpandBroadcastTensorShape,
FoldQDQ,
I64toI32,
LayoutTransform,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
RemoveRedundancy,
ReplaceIndexPutInput,
)

return {
RecomposePixelUnshuffle: [RemoveRedundancy],
RecomposeRmsNorm: [RemoveRedundancy],
ConvertToLinear: [RecomposePixelUnshuffle],
ConvertPReLU: [RemoveRedundancy],
ConvertBmmToMatmul: [ConvertToLinear],
ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
I64toI32: [RemoveRedundancy],
AnnotateQuantAttrs: [
RecomposePixelUnshuffle,
RecomposeRmsNorm,
ConvertToLinear,
ConvertPReLU,
ConvertBmmToMatmul,
ConvertInterpolateWithUpsample2D,
],
AnnotateAndQuantScalar: [
AnnotateQuantAttrs,
],
AnnotateDecomposed: [RemoveRedundancy],
FoldQDQ: [AnnotateQuantAttrs, AnnotateAndQuantScalar, AnnotateDecomposed],
ExpandBroadcastTensorShape: [RemoveRedundancy],
LayoutTransform: [
AnnotateQuantAttrs,
AnnotateAndQuantScalar,
ExpandBroadcastTensorShape,
],
ReplaceIndexPutInput: [LayoutTransform],
}
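The table above is effectively a dependency graph, so a concrete pass ordering can be derived from it with a topological sort. A minimal sketch (not part of this PR) using the standard-library graphlib:

from graphlib import TopologicalSorter

from executorch.backends.qualcomm._passes.utils import (
    get_passes_dependency_for_capture_program,
)

def ordered_capture_program_passes():
    # TopologicalSorter takes {node: predecessors}, which matches the
    # {pass: [passes it depends on]} layout returned above.
    deps = get_passes_dependency_for_capture_program()
    return list(TopologicalSorter(deps).static_order())

# RemoveRedundancy, for example, is guaranteed to come out before every pass
# that lists it as a dependency.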
3 changes: 3 additions & 0 deletions backends/qualcomm/scripts/build.sh
@@ -87,6 +87,7 @@ if [ "$BUILD_AARCH64" = true ]; then
-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK_ROOT/build/cmake/android.toolchain.cmake \
-DANDROID_ABI='arm64-v8a' \
-DANDROID_NATIVE_API_LEVEL=23 \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-B$BUILD_ROOT

@@ -101,6 +102,7 @@ if [ "$BUILD_AARCH64" = true ]; then
-DANDROID_ABI='arm64-v8a' \
-DANDROID_NATIVE_API_LEVEL=23 \
-DCMAKE_PREFIX_PATH=$CMAKE_PREFIX_PATH \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=BOTH \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
-B$EXAMPLE_ROOT
@@ -125,6 +127,7 @@ if [ "$BUILD_X86_64" = true ]; then
-DEXECUTORCH_BUILD_QNN=ON \
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_ENABLE_EVENT_TRACER=ON \
-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE \
4 changes: 2 additions & 2 deletions backends/qualcomm/utils/constants.py
@@ -26,8 +26,8 @@
QCOM_SCALE_OFFSET = "scale_offset"
QCOM_ZERO_POINT = "zero_point"
QCOM_ZERO_POINTS = "zero_points"
QCOM_PASS_EXPAND_BROADCAST_SHAPE = "expand_broadcast_shape"
QCOM_PASS_SKIP_ADVANCED_REQUANT = "skip_advanced_requant"
QCOM_PASS_ACTIVATE_KEY = "activate"
QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY = "args_kwargs_defaults"

# constants in backends/qualcomm/tests
QCOM_ANNOTATION = "annotation"
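The two per-pass constants removed above are replaced by generic keys, so the options for any custom pass can be expressed as one small dictionary instead of a dedicated flag per pass. A hedged sketch of what such a per-pass options table might look like; the exact layout consumed downstream by capture_program is an assumption here, not shown in this diff:

from executorch.backends.qualcomm._passes import ExpandBroadcastTensorShape, I64toI32
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_ACTIVATE_KEY,
    QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY,
)

custom_pass_options = {
    # Activate an optional pass without a dedicated constant for it.
    ExpandBroadcastTensorShape: {
        QCOM_PASS_ACTIVATE_KEY: True,
        QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: {},
    },
    # Override a keyword default, e.g. the new skip_node argument of I64toI32
    # ("tokens" is an illustrative node name, not from this PR).
    I64toI32: {
        QCOM_PASS_ACTIVATE_KEY: True,
        QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY: {"skip_node": frozenset({"tokens"})},
    },
}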