Qualcomm AI Engine Direct - Fix Argmin #8308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 2 commits · Feb 11, 2025
6 changes: 4 additions & 2 deletions backends/qualcomm/_passes/__init__.py
@@ -1,18 +1,19 @@
from .annotate_and_quant_scalar import AnnotateAndQuantScalar
from .annotate_decomposed import AnnotateDecomposed
from .annotate_quant_attrs import AnnotateQuantAttrs
from .constant_i64_to_i32 import ConstantI64toI32
from .convert_bmm_to_matmul import ConvertBmmToMatmul
from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
from .convert_prelu import ConvertPReLU
from .convert_to_linear import ConvertToLinear
from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
from .fold_qdq import FoldQDQ
from .i64_to_i32 import I64toI32
from .layout_transform import LayoutTransform
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
from .recompose_rms_norm import RecomposeRmsNorm
from .remove_redundancy import RemoveRedundancy
from .replace_index_put_input import ReplaceIndexPutInput
from .tensor_i64_to_i32 import TensorI64toI32


__all__ = [
@@ -25,7 +26,8 @@
ConvertToLinear,
ExpandBroadcastTensorShape,
FoldQDQ,
I64toI32,
ConstantI64toI32,
TensorI64toI32,
LayoutTransform,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
backends/qualcomm/_passes/constant_i64_to_i32.py (renamed from i64_to_i32.py)
@@ -12,17 +12,18 @@
from torch._subclasses.fake_tensor import FakeTensor


class I64toI32(ExportPass):
class ConstantI64toI32(ExportPass):
"""
Cast the unsupported int64 data type to int32.
This is only applied to constant nodes such as weights.
"""

def __init__(
self,
edge_program: torch.export.ExportedProgram,
skip_node: FrozenSet[str] = frozenset(),
):
super(I64toI32, self).__init__()
super(ConstantI64toI32, self).__init__()
self.edge_program = edge_program
self.skip_node = skip_node
# pyre-ignore[4]
1 change: 0 additions & 1 deletion backends/qualcomm/_passes/layout_transform.py
@@ -45,7 +45,6 @@ class LayoutTransform(ExportPass):
layout_agnostic_ops = {
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.argmin.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.ceil.default,
128 changes: 128 additions & 0 deletions backends/qualcomm/_passes/tensor_i64_to_i32.py
@@ -0,0 +1,128 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.qualcomm.builders.utils import is_graph_output
from executorch.backends.qualcomm.utils.constants import QCOM_ORIG_DTYPE
from executorch.exir import ExirExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.program._program import _get_updated_graph_signature
from torch._subclasses.fake_tensor import FakeTensor


class TensorI64toI32(ExportPass):
"""
Insert a cast node to convert the dtype from int64 to int32.
This is only applied to fake tensors.
"""

cast_ops = {
torch.ops.aten.argmin.default,
}

def __init__(self, edge_program):
super(TensorI64toI32, self).__init__()
self.edge_program = edge_program

# pyre-ignore[2]
def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

def _cast_to_int32(self, core_ep: ExirExportedProgram):
copy_op = torch.ops.aten._to_copy.default
for n in core_ep.exported_program.graph.nodes:
# Keep track of the original output dtype so we can ensure the graph output dtype stays consistent with the nn.Module
if is_graph_output(n):
if isinstance(n.meta["val"], tuple):
dtype_list = [tensor.dtype for tensor in n.meta["val"]]
n.meta[QCOM_ORIG_DTYPE] = dtype_list
else:
n.meta[QCOM_ORIG_DTYPE] = n.meta["val"].dtype
continue
if n.target in self.cast_ops:
node_val = n.meta["val"]
if self._is_tensor_of_dtype(node_val, torch.int64):
with core_ep.exported_program.graph.inserting_after(n):
users = list(n.users.keys())
args = (n,)
cast_node = core_ep.exported_program.graph.create_node(
"call_function",
copy_op,
args,
{"dtype": torch.int32},
)
cast_node.meta["val"] = node_val.to(torch.int32)
cast_node.args = args

for user in users:
user.replace_input_with(n, cast_node)

core_ep.exported_program._graph_signature = _get_updated_graph_signature(
core_ep.exported_program._graph_signature,
core_ep.exported_program.graph_module,
)
core_ep.exported_program._validate()

def _preserve_output_dtype(
self, exported_program: torch.export.exported_program.ExportedProgram
):
graph_module = exported_program.graph_module
copy_op = exir_ops.edge.aten._to_copy.default
for n in graph_module.graph.nodes:
if is_graph_output(n) and QCOM_ORIG_DTYPE in n.meta:
if isinstance(n.meta["val"], tuple):
for i, dtype in enumerate(n.meta[QCOM_ORIG_DTYPE]):
# TODO: Enable this in the future to support ops such as topk
if n.meta["val"][i].dtype != dtype:
raise AssertionError(
"Multi output nodes currently don't support casting dtype back."
)
elif n.meta["val"].dtype != n.meta[QCOM_ORIG_DTYPE]:
if n.meta[QCOM_ORIG_DTYPE] != torch.int64:
logging.warning(
"This pass is intended to maintain output as int64 when nn.Module outputs int64. Other dtype modification is detected. Please ensure this is desired."
)
with graph_module.graph.inserting_after(n):
orig_dtype = n.meta[QCOM_ORIG_DTYPE]
node_val = n.meta["val"]
args = (n,)
users = list(n.users.keys())
output_users = [
user for user in users if user.target == "output"
]
cast_node = graph_module.graph.create_node(
"call_function",
copy_op,
args,
{"dtype": orig_dtype},
)
cast_node.meta["val"] = node_val.to(orig_dtype)
cast_node.args = args
for user in output_users:
user.replace_input_with(n, cast_node)

def call(self, graph_module: torch.fx.GraphModule):
# Stage 1: _cast_to_int32
# We insert to_copy after the target ops in this stage because dtype propagation only happens before to_edge.
# If the to_copy were added after to_edge instead, the op that consumes it would still expect int64.
# Stage 2: _preserve_output_dtype
# The original output dtype is tagged during stage 1; if the user expects an int64 output,
# the graph output is cast back to int64 whenever it was converted from int64 to int32 in stage 1.
if isinstance(self.edge_program, ExirExportedProgram):
self._cast_to_int32(self.edge_program)
self.edge_program.exported_program.graph_module.recompile()
elif isinstance(
self.edge_program, torch.export.exported_program.ExportedProgram
):
self._preserve_output_dtype(self.edge_program)
else:
raise AssertionError(
"Should be ExirExportedProgram at stage 1 and torch.export.exported_program.ExportedProgram at stage 2"
)
return PassResult(graph_module, True)
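For readers unfamiliar with this kind of graph surgery, the following is a minimal standalone sketch (not the ExecuTorch pass itself) of the stage-1 idea: export a small argmin model, find the int64-producing node, and splice an aten._to_copy cast to int32 between it and its users. The model, shapes, and direct mutation of the exported program's graph are illustrative assumptions.

import torch

class ArgminModel(torch.nn.Module):
    def forward(self, x):
        return torch.argmin(x, dim=0, keepdim=True)

# Export to an FX graph; the argmin node's fake tensor reports dtype int64.
ep = torch.export.export(ArgminModel(), (torch.randn(4, 4),))
graph = ep.graph_module.graph

for n in graph.nodes:
    if n.target == torch.ops.aten.argmin.default:
        users = list(n.users.keys())
        with graph.inserting_after(n):
            # Insert an explicit int64 -> int32 cast right after argmin.
            cast = graph.create_node(
                "call_function",
                torch.ops.aten._to_copy.default,
                (n,),
                {"dtype": torch.int32},
            )
        # Keep the fake-tensor metadata consistent with the new dtype.
        cast.meta["val"] = n.meta["val"].to(torch.int32)
        # Rewire every consumer (including the output node) to the cast.
        for user in users:
            user.replace_input_with(n, cast)

ep.graph_module.recompile()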
6 changes: 4 additions & 2 deletions backends/qualcomm/_passes/utils.py
@@ -60,18 +60,19 @@ def get_passes_dependency_for_capture_program():
AnnotateAndQuantScalar,
AnnotateDecomposed,
AnnotateQuantAttrs,
ConstantI64toI32,
ConvertBmmToMatmul,
ConvertInterpolateWithUpsample2D,
ConvertPReLU,
ConvertToLinear,
ExpandBroadcastTensorShape,
FoldQDQ,
I64toI32,
LayoutTransform,
RecomposePixelUnshuffle,
RecomposeRmsNorm,
RemoveRedundancy,
ReplaceIndexPutInput,
TensorI64toI32,
)

return {
@@ -81,7 +82,8 @@ def get_passes_dependency_for_capture_program():
ConvertPReLU: [RemoveRedundancy],
ConvertBmmToMatmul: [ConvertToLinear],
ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
I64toI32: [RemoveRedundancy],
ConstantI64toI32: [RemoveRedundancy],
TensorI64toI32: [RemoveRedundancy],
AnnotateQuantAttrs: [
RecomposePixelUnshuffle,
RecomposeRmsNorm,
52 changes: 41 additions & 11 deletions backends/qualcomm/builders/op_argmin.py
@@ -10,8 +10,8 @@
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW
from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor
from .qnn_constants import OpArgmin, OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
@@ -26,8 +26,10 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
op_wrapper_list = []
input_node = node.args[0]
input_tensor = self.get_tensor(input_node, node)
output_tensor = self.get_tensor(node, node)
argmin_inp_tensor_wrapper = self.define_tensor(
input_node,
node,
Expand All @@ -37,17 +39,25 @@ def define_node(
)
argmin_input_tensors = [argmin_inp_tensor_wrapper]

output_tensor = self.get_tensor(node, node).to(torch.int32)
# argmin output is an index, do not quantize it.
node.meta.pop("quant_attrs", None)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
input_node, node
)
argmin_output_tensors = [output_tensor_wrapper]

argmin_intermediate_tensor_wrapper = self.define_custom_tensor_wrapper(
node_name=node.name + "_cast",
tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
dtype=QNN_TENSOR_TYPE_MAP[torch.int32],
quant_encoding=input_quant_encoding,
quant_configs=input_quant_configs,
dims=output_tensor.size(),
tensor=output_tensor,
is_fake_tensor=True,
nodes_to_wrappers=nodes_to_wrappers,
)

argmin_output_tensors = [argmin_intermediate_tensor_wrapper]

dim = cast(int, node.args[1])
if dim < 0:
@@ -77,4 +87,24 @@
{QCOM_DATA: keep_dims},
)

return argmin_op
op_wrapper_list.append(argmin_op)

cast_op = PyQnnWrapper.PyQnnOpWrapper(
node.name + "_cast",
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpCast.op_name,
)

output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)

cast_op.AddInputTensors([argmin_intermediate_tensor_wrapper])
cast_op.AddOutputTensors([output_tensor_wrapper])
op_wrapper_list.append(cast_op)

return op_wrapper_list
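The builder now returns a list of two QNN ops, Argmin followed by Cast, instead of a single op. The motivation is the dtype mismatch between the frameworks: torch.argmin always yields int64 indices, so the op output is routed through an int32 intermediate tensor and then cast to the dtype the rest of the graph expects (the QNN-side behaviour described here is inferred from this diff). A quick plain-PyTorch check of the eager-side dtype:

import torch

x = torch.randn(2, 3, 4)
idx = torch.argmin(x, dim=1, keepdim=True)
print(idx.dtype)                  # torch.int64 -- what eager PyTorch produces
print(idx.to(torch.int32).dtype)  # torch.int32 -- what the intermediate tensor carries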
4 changes: 2 additions & 2 deletions backends/qualcomm/builders/utils.py
@@ -75,14 +75,14 @@ def is_graph_input(
return tensor.op == "placeholder" and not is_parameter(tensor, edge_program)


def is_graph_output(tensor: torch.fx.Node) -> bool:
def is_graph_output(node: torch.fx.Node) -> bool:
"""
Check if the given node is used as a graph output

Args:
node: EdgeIR node that is being checked for graph output
"""
for user in tensor.users.keys():
for user in node.users.keys():
# getitem node is skipped, check op_skip_ops.py
if user.op == "output" or (
user.target.__name__ == "getitem" and is_graph_output(user)
17 changes: 16 additions & 1 deletion backends/qualcomm/quantizer/annotators.py
@@ -110,6 +110,21 @@ def annotate_in_out_obs_sharing_op(
)


def annotate_single_in(node: Node, quantization_config: QuantizationConfig) -> None:
if _is_annotated([node]):
return

input_qspec_map = {}
input_act = node.args[0]
assert isinstance(input_act, Node)
input_qspec_map[input_act] = quantization_config.input_activation

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
_annotated=True,
)


def annotate_single_in_single_out(
node: Node, quantization_config: QuantizationConfig
) -> None:
@@ -171,7 +186,7 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
if _is_annotated([node]):
return
annotate_single_in_single_out(node, quantization_config)
annotate_single_in(node, quantization_config)


@register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor])
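The new annotate_single_in differs from annotate_single_in_single_out only in that it leaves the node's output without a quant spec, which is what argmin needs since its output is an integer index. A rough sketch of what each annotator attaches, assuming the standard pt2e meta key "quantization_annotation" and the QuantizationConfig fields used elsewhere in this file:

from torch.ao.quantization.quantizer import QuantizationAnnotation

def annotate_single_in_sketch(node, quantization_config):
    # Only the input activation is observed; the index output is left unannotated.
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={node.args[0]: quantization_config.input_activation},
        _annotated=True,
    )

def annotate_single_in_single_out_sketch(node, quantization_config):
    # Both the input activation and the output get quant specs.
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map={node.args[0]: quantization_config.input_activation},
        output_qspec=quantization_config.output_activation,
        _annotated=True,
    )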
27 changes: 27 additions & 0 deletions backends/qualcomm/tests/models.py
@@ -66,6 +66,33 @@ def forward(self, y):
)


class Argmin(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
x = torch.argmin(x, dim=0, keepdim=True)
return x


class ArgminViewSqueezeConv2D(torch.nn.Module):
def __init__(self):
# This model mainly exists to test the TensorI64toI32 pass
super().__init__()
self.conv = torch.nn.Conv2d(
in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
)

def forward(self, x, y):
argmin_out = torch.argmin(x, dim=0, keepdim=True)
index_out = y[argmin_out]
conv_out = self.conv(index_out)

view_out = argmin_out.view(-1)
squeeze_out = view_out.squeeze(-1)
return squeeze_out, conv_out


class AvgPoolModule(torch.nn.Module):
def __init__(self):
super().__init__()
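ArgminViewSqueezeConv2D is built so the argmin result is consumed both as an index into another tensor and as a graph output, which exercises both stages of TensorI64toI32. A quick eager-mode run with hypothetical shapes (the shapes used by the actual unit test are not shown in this diff):

import torch
from executorch.backends.qualcomm.tests.models import ArgminViewSqueezeConv2D

model = ArgminViewSqueezeConv2D()
x = torch.randn(4)            # argmin over dim=0 -> a single int64 index, shape (1,)
y = torch.randn(4, 3, 8, 8)   # indexed by that result, then fed to the conv

squeeze_out, conv_out = model(x, y)
print(squeeze_out.dtype)  # torch.int64 in eager mode; the pass preserves this at the graph output
print(conv_out.shape)     # torch.Size([1, 16, 8, 8])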