
Commit 5c52fbe

Qualcomm AI Engine Direct - Fix Argmin (#8308)
* Annotate input only for argmin.
* Update the argmin op builder so it outputs int64, aligning with PyTorch.
* Add a pass to cast argmin output to int32, since most ops in QNN do not support int64.
* Update library path.
1 parent 78752a0 commit 5c52fbe
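
For background on the first two points: torch.argmin always returns int64 indices in PyTorch, while most QNN ops only handle int32, hence the new cast pass. A quick eager-mode check:

import torch

x = torch.randn(3, 5)
print(torch.argmin(x, dim=0).dtype)  # torch.int64; QNN needs int32 here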

13 files changed: 261 additions, 31 deletions

backends/qualcomm/_passes/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -1,18 +1,19 @@
 from .annotate_and_quant_scalar import AnnotateAndQuantScalar
 from .annotate_decomposed import AnnotateDecomposed
 from .annotate_quant_attrs import AnnotateQuantAttrs
+from .constant_i64_to_i32 import ConstantI64toI32
 from .convert_bmm_to_matmul import ConvertBmmToMatmul
 from .convert_interpolate_with_upsample2d import ConvertInterpolateWithUpsample2D
 from .convert_prelu import ConvertPReLU
 from .convert_to_linear import ConvertToLinear
 from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fold_qdq import FoldQDQ
-from .i64_to_i32 import I64toI32
 from .layout_transform import LayoutTransform
 from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
 from .recompose_rms_norm import RecomposeRmsNorm
 from .remove_redundancy import RemoveRedundancy
 from .replace_index_put_input import ReplaceIndexPutInput
+from .tensor_i64_to_i32 import TensorI64toI32


 __all__ = [
@@ -25,7 +26,8 @@
     ConvertToLinear,
     ExpandBroadcastTensorShape,
     FoldQDQ,
-    I64toI32,
+    ConstantI64toI32,
+    TensorI64toI32,
     LayoutTransform,
     RecomposePixelUnshuffle,
     RecomposeRmsNorm,

backends/qualcomm/_passes/i64_to_i32.py renamed to backends/qualcomm/_passes/constant_i64_to_i32.py

Lines changed: 3 additions & 2 deletions
@@ -12,17 +12,18 @@
 from torch._subclasses.fake_tensor import FakeTensor


-class I64toI32(ExportPass):
+class ConstantI64toI32(ExportPass):
     """
     Cast unsupported int64 datatype into int32.
+    This will only be applied on constant nodes such as weights.
     """

     def __init__(
         self,
         edge_program: torch.export.ExportedProgram,
         skip_node: FrozenSet[str] = frozenset(),
     ):
-        super(I64toI32, self).__init__()
+        super(ConstantI64toI32, self).__init__()
         self.edge_program = edge_program
         self.skip_node = skip_node
         # pyre-ignore[4]
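
The rename marks a split in responsibility: ConstantI64toI32 rewrites int64 constants such as weights in place, while the new TensorI64toI32 pass (next file) inserts explicit cast nodes for int64 runtime tensors like argmin indices. Roughly, in eager terms (illustrative values, not part of the commit):

import torch

weight = torch.tensor([1, 2, 3], dtype=torch.int64)
weight_i32 = weight.to(torch.int32)  # ConstantI64toI32: data rewritten at export time
# TensorI64toI32 instead adds an aten._to_copy(dtype=torch.int32) node after
# ops like argmin, since those tensor values only exist at runtime.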

backends/qualcomm/_passes/layout_transform.py

Lines changed: 0 additions & 1 deletion
@@ -45,7 +45,6 @@ class LayoutTransform(ExportPass):
     layout_agnostic_ops = {
         exir_ops.edge.aten.abs.default,
         exir_ops.edge.aten.add.Tensor,
-        exir_ops.edge.aten.argmin.default,
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.cat.default,
         exir_ops.edge.aten.ceil.default,
backends/qualcomm/_passes/tensor_i64_to_i32.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+from executorch.backends.qualcomm.builders.utils import is_graph_output
+from executorch.backends.qualcomm.utils.constants import QCOM_ORIG_DTYPE
+from executorch.exir import ExirExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.program._program import _get_updated_graph_signature
+from torch._subclasses.fake_tensor import FakeTensor
+
+
+class TensorI64toI32(ExportPass):
+    """
+    Insert a cast node to cast dtype from int64 to int32.
+    This will only be applied on fake tensors.
+    """
+
+    cast_ops = {
+        torch.ops.aten.argmin.default,
+    }
+
+    def __init__(self, edge_program):
+        super(TensorI64toI32, self).__init__()
+        self.edge_program = edge_program
+
+    # pyre-ignore[2]
+    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
+        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype
+
+    def _cast_to_int32(self, core_ep: ExirExportedProgram):
+        copy_op = torch.ops.aten._to_copy.default
+        for n in core_ep.exported_program.graph.nodes:
+            # Keep track of original output dtype so we ensure the dtype of the graph is consistent with nn.Module
+            if is_graph_output(n):
+                if isinstance(n.meta["val"], tuple):
+                    dtype_list = [tensor.dtype for tensor in n.meta["val"]]
+                    n.meta[QCOM_ORIG_DTYPE] = dtype_list
+                else:
+                    n.meta[QCOM_ORIG_DTYPE] = n.meta["val"].dtype
+                continue
+            if n.target in self.cast_ops:
+                node_val = n.meta["val"]
+                if self._is_tensor_of_dtype(node_val, torch.int64):
+                    with core_ep.exported_program.graph.inserting_after(n):
+                        users = list(n.users.keys())
+                        args = (n,)
+                        cast_node = core_ep.exported_program.graph.create_node(
+                            "call_function",
+                            copy_op,
+                            args,
+                            {"dtype": torch.int32},
+                        )
+                        cast_node.meta["val"] = node_val.to(torch.int32)
+                        cast_node.args = args
+
+                        for user in users:
+                            user.replace_input_with(n, cast_node)
+
+        core_ep.exported_program._graph_signature = _get_updated_graph_signature(
+            core_ep.exported_program._graph_signature,
+            core_ep.exported_program.graph_module,
+        )
+        core_ep.exported_program._validate()
+
+    def _preserve_output_dtype(
+        self, exported_program: torch.export.exported_program.ExportedProgram
+    ):
+        graph_module = exported_program.graph_module
+        copy_op = exir_ops.edge.aten._to_copy.default
+        for n in graph_module.graph.nodes:
+            if is_graph_output(n) and QCOM_ORIG_DTYPE in n.meta:
+                if isinstance(n.meta["val"], tuple):
+                    for i, dtype in enumerate(n.meta[QCOM_ORIG_DTYPE]):
+                        # TODO: Enable this in future to support OP such as topK
+                        if n.meta["val"][i].dtype != dtype:
+                            raise AssertionError(
+                                "Multi output nodes currently don't support casting dtype back."
+                            )
+                elif n.meta["val"].dtype != n.meta[QCOM_ORIG_DTYPE]:
+                    if n.meta[QCOM_ORIG_DTYPE] != torch.int64:
+                        logging.warning(
+                            "This pass is intended to maintain output as int64 when nn.Module outputs int64. Other dtype modification is detected. Please ensure this is desired."
+                        )
+                    with graph_module.graph.inserting_after(n):
+                        orig_dtype = n.meta[QCOM_ORIG_DTYPE]
+                        node_val = n.meta["val"]
+                        args = (n,)
+                        users = list(n.users.keys())
+                        output_users = [
+                            user for user in users if user.target == "output"
+                        ]
+                        cast_node = graph_module.graph.create_node(
+                            "call_function",
+                            copy_op,
+                            args,
+                            {"dtype": orig_dtype},
+                        )
+                        cast_node.meta["val"] = node_val.to(orig_dtype)
+                        cast_node.args = args
+                        for user in output_users:
+                            user.replace_input_with(n, cast_node)
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        # Stage 1: _cast_to_int32
+        # We add to_copy after the desired operations during this stage because the data type only propagates before to_edge.
+        # If we don't add to_copy here but do it after to_edge, the next operation after to_copy() will still expect int64 as its output.
+        # Stage 2: _preserve_output_dtype
+        # We will tag the output dtype during stage 1, and we will ensure that if user expects int64 as output,
+        # we need to convert the output back to int64 if it is casted from int64->int32 during stage 1.
+        if isinstance(self.edge_program, ExirExportedProgram):
+            self._cast_to_int32(self.edge_program)
+            self.edge_program.exported_program.graph_module.recompile()
+        elif isinstance(
+            self.edge_program, torch.export.exported_program.ExportedProgram
+        ):
+            self._preserve_output_dtype(self.edge_program)
+        else:
+            raise AssertionError(
+                "Should be ExirExportedProgram at stage 1 and torch.export.exported_program.ExportedProgram at stage 2"
+            )
+        return PassResult(graph_module, True)
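
A sketch of how the two stages are meant to be invoked (`core_ep` and `final_ep` are placeholders; the real call sites live in the capture pipeline): stage 1 receives the pre-to_edge ExirExportedProgram and rewires int64 producers to int32, while stage 2 receives the final torch.export ExportedProgram and restores int64 at the graph output if the original nn.Module returned int64.

from executorch.backends.qualcomm._passes import TensorI64toI32

# Stage 1: core_ep is an ExirExportedProgram (before to_edge).
TensorI64toI32(core_ep).call(core_ep.exported_program.graph_module)

# Stage 2: final_ep is a torch.export ExportedProgram (after lowering).
TensorI64toI32(final_ep).call(final_ep.graph_module)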

backends/qualcomm/_passes/utils.py

Lines changed: 4 additions & 2 deletions
@@ -60,18 +60,19 @@ def get_passes_dependency_for_capture_program():
         AnnotateAndQuantScalar,
         AnnotateDecomposed,
         AnnotateQuantAttrs,
+        ConstantI64toI32,
         ConvertBmmToMatmul,
         ConvertInterpolateWithUpsample2D,
         ConvertPReLU,
         ConvertToLinear,
         ExpandBroadcastTensorShape,
         FoldQDQ,
-        I64toI32,
         LayoutTransform,
         RecomposePixelUnshuffle,
         RecomposeRmsNorm,
         RemoveRedundancy,
         ReplaceIndexPutInput,
+        TensorI64toI32,
     )

     return {
@@ -81,7 +82,8 @@ def get_passes_dependency_for_capture_program():
         ConvertPReLU: [RemoveRedundancy],
         ConvertBmmToMatmul: [ConvertToLinear],
         ConvertInterpolateWithUpsample2D: [RemoveRedundancy],
-        I64toI32: [RemoveRedundancy],
+        ConstantI64toI32: [RemoveRedundancy],
+        TensorI64toI32: [RemoveRedundancy],
         AnnotateQuantAttrs: [
             RecomposePixelUnshuffle,
             RecomposeRmsNorm,
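
The dict returned here maps each pass to the passes that must run before it; both new passes only depend on RemoveRedundancy. For illustration, a schedule could be derived from such a dict with a standard topological sort (a sketch, not the actual scheduler used by capture_program):

from collections import deque

def topological_order(deps):
    # deps: {pass_cls: [pass classes that must run first]}
    nodes = set(deps) | {d for ds in deps.values() for d in ds}
    waiting = {n: set(deps.get(n, ())) for n in nodes}
    ready = deque(n for n, pre in waiting.items() if not pre)
    order = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for m, pre in waiting.items():
            if n in pre:
                pre.discard(n)
                if not pre:
                    ready.append(m)
    assert len(order) == len(nodes), "cycle in pass dependencies"
    return order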

backends/qualcomm/builders/op_argmin.py

Lines changed: 41 additions & 11 deletions
@@ -10,8 +10,8 @@
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

-from .node_visitor import NodeVisitor, register_node_visitor
-from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW
+from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor
+from .qnn_constants import OpArgmin, OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW


 @register_node_visitor
@@ -26,8 +26,10 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
+        op_wrapper_list = []
         input_node = node.args[0]
         input_tensor = self.get_tensor(input_node, node)
+        output_tensor = self.get_tensor(node, node)
         argmin_inp_tensor_wrapper = self.define_tensor(
             input_node,
             node,
@@ -37,17 +39,25 @@ def define_node(
         )
         argmin_input_tensors = [argmin_inp_tensor_wrapper]

-        output_tensor = self.get_tensor(node, node).to(torch.int32)
         # arg output is index, do not quantize it.
         node.meta.pop("quant_attrs", None)
-        output_tensor_wrapper = self.define_tensor(
-            node,
-            node,
-            output_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            nodes_to_wrappers,
+        input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
+            input_node, node
         )
-        argmin_output_tensors = [output_tensor_wrapper]
+
+        argmin_intermediate_tensor_wrapper = self.define_custom_tensor_wrapper(
+            node_name=node.name + "_cast",
+            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            dtype=QNN_TENSOR_TYPE_MAP[torch.int32],
+            quant_encoding=input_quant_encoding,
+            quant_configs=input_quant_configs,
+            dims=output_tensor.size(),
+            tensor=output_tensor,
+            is_fake_tensor=True,
+            nodes_to_wrappers=nodes_to_wrappers,
+        )
+
+        argmin_output_tensors = [argmin_intermediate_tensor_wrapper]

         dim = cast(int, node.args[1])
         if dim < 0:
@@ -77,4 +87,24 @@ def define_node(
             {QCOM_DATA: keep_dims},
         )

-        return argmin_op
+        op_wrapper_list.append(argmin_op)
+
+        cast_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name + "_cast",
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpCast.op_name,
+        )
+
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        cast_op.AddInputTensors([argmin_intermediate_tensor_wrapper])
+        cast_op.AddOutputTensors([output_tensor_wrapper])
+        op_wrapper_list.append(cast_op)
+
+        return op_wrapper_list
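
The builder now always emits a two-op sequence: Argmin writing an int32 intermediate tensor, followed by a Cast to whatever dtype the FX node declares. In rough eager-mode terms (an illustration only; output_dtype depends on whether TensorI64toI32 already rewrote the node):

import torch

def qnn_argmin_lowering(x, dim, keepdim, output_dtype):
    # QNN Argmin: index computed as int32 (the intermediate tensor wrapper)
    idx_i32 = torch.argmin(x, dim=dim, keepdim=keepdim).to(torch.int32)
    # QNN Cast: materialize the dtype recorded on the FX node
    return idx_i32.to(output_dtype)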

backends/qualcomm/builders/utils.py

Lines changed: 2 additions & 2 deletions
@@ -75,14 +75,14 @@ def is_graph_input(
     return tensor.op == "placeholder" and not is_parameter(tensor, edge_program)


-def is_graph_output(tensor: torch.fx.Node) -> bool:
+def is_graph_output(node: torch.fx.Node) -> bool:
     """
     Check if the given tensor is used as a graph output

     Args:
         tensor: EdgeIR Tensor that is being checked for graph input
     """
-    for user in tensor.users.keys():
+    for user in node.users.keys():
         # getitem node is skiped, check the op_skip_ops.py
         if user.op == "output" or (
             user.target.__name__ == "getitem" and is_graph_output(user)

backends/qualcomm/quantizer/annotators.py

Lines changed: 16 additions & 1 deletion
@@ -110,6 +110,21 @@ def annotate_in_out_obs_sharing_op(
     )


+def annotate_single_in(node: Node, quantization_config: QuantizationConfig) -> None:
+    if _is_annotated([node]):
+        return
+
+    input_qspec_map = {}
+    input_act = node.args[0]
+    assert isinstance(input_act, Node)
+    input_qspec_map[input_act] = quantization_config.input_activation
+
+    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+        input_qspec_map=input_qspec_map,
+        _annotated=True,
+    )
+
+
 def annotate_single_in_single_out(
     node: Node, quantization_config: QuantizationConfig
 ) -> None:
@@ -171,7 +186,7 @@ def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
 def annotate_argmin(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return
-    annotate_single_in_single_out(node, quantization_config)
+    annotate_single_in(node, quantization_config)


 @register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor])
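
The new helper exists because argmin produces an integer index rather than a quantizable activation, so only its input should be observed. A sketch of the contrast (`inp` and `cfg` stand in for node.args[0] and the QuantizationConfig seen above):

from torch.ao.quantization.quantizer import QuantizationAnnotation

# annotate_single_in: observe the input only, right for index-producing ops
index_style = QuantizationAnnotation(
    input_qspec_map={inp: cfg.input_activation},
    _annotated=True,
)

# annotate_single_in_single_out would additionally set
#     output_qspec=cfg.output_activation,
# which would wrongly attach an observer to argmin's integer index output.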

backends/qualcomm/tests/models.py

Lines changed: 27 additions & 0 deletions
@@ -66,6 +66,33 @@ def forward(self, y):
         )


+class Argmin(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        x = torch.argmin(x, dim=0, keepdim=True)
+        return x
+
+
+class ArgminViewSqueezeConv2D(torch.nn.Module):
+    def __init__(self):
+        # This model is mainly to test the PASS TensorI64toI32
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, x, y):
+        argmin_out = torch.argmin(x, dim=0, keepdim=True)
+        index_out = y[argmin_out]
+        conv_out = self.conv(index_out)
+
+        view_out = argmin_out.view(-1)
+        squeeze_out = view_out.squeeze(-1)
+        return squeeze_out, conv_out
+
+
 class AvgPoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
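
A quick eager-mode run of the new test model (shapes chosen for illustration; the test suite supplies its own inputs) shows the int64 index flowing through both the indexing path and the graph output, which is exactly what exercises TensorI64toI32:

import torch

model = ArgminViewSqueezeConv2D()
x = torch.randn(4)             # argmin over dim=0 -> int64 index of shape (1,)
y = torch.randn(4, 3, 8, 8)    # gathered by that index before the conv
squeeze_out, conv_out = model(x, y)
print(squeeze_out.dtype)       # torch.int64 in eager; the QNN flow runs int32
                               # internally and casts back at the boundary
print(conv_out.shape)          # torch.Size([1, 16, 8, 8])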
