Commit cfba192
Qualcomm AI Engine Direct - op support (#8306)
- where / logical_not - test cases
1 parent b751645
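
In eager PyTorch, the two newly supported ops look as follows (a minimal sketch; shapes are illustrative):

    import torch

    x = torch.randn(3, 2)
    mask = x >= 0                   # boolean condition
    inv = torch.logical_not(mask)   # lowers to aten.logical_not.default
    out = torch.where(mask, x, -x)  # lowers to aten.where.self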

File tree: 9 files changed, +237 / -3 lines

backends/qualcomm/_passes/layout_transform.py

Lines changed: 2 additions & 0 deletions

@@ -68,6 +68,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.log.default,
+        exir_ops.edge.aten.logical_not.default,
         exir_ops.edge.aten.lt.Scalar,
         exir_ops.edge.aten.lt.Tensor,
         exir_ops.edge.aten._log_softmax.default,
@@ -88,6 +89,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.sum.dim_IntList,
         exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
+        exir_ops.edge.aten.where.self,
         *q_ops,
         *dq_ops,
         _operator.getitem,
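
Both ops are registered as layout-agnostic here because they are pointwise: they commute with any axis permutation the NHWC transform may apply. A quick eager-mode check of that property (my reading of the pass, not part of the commit):

    import torch

    x = torch.randn(1, 2, 3, 4)
    mask = x >= 0
    nhwc = (0, 2, 3, 1)
    a = torch.where(mask, x, -x).permute(nhwc)
    b = torch.where(mask.permute(nhwc), x.permute(nhwc), (-x).permute(nhwc))
    assert torch.equal(a, b)  # pointwise ops commute with permutation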

backends/qualcomm/builders/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -40,6 +40,7 @@
     op_linear,
     op_log,
     op_log_softmax,
+    op_logical_not,
     op_lt,
     op_matmul,
     op_max,
@@ -76,6 +77,7 @@
     op_unsqueeze,
     op_upsample_bilinear2d,
     op_upsample_nearest2d,
+    op_where,
 )

 __all__ = [
@@ -113,6 +115,7 @@
     op_le,
     op_linear,
     op_log,
+    op_logical_not,
     op_log_softmax,
     op_lt,
     op_matmul,
@@ -150,4 +153,5 @@
     op_unsqueeze,
     op_upsample_bilinear2d,
     op_upsample_nearest2d,
+    op_where,
 ]
backends/qualcomm/builders/op_logical_not.py (new file)

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseNot, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Not(NodeVisitor):
+    target = ["aten.logical_not.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        logical_not_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseNot.op_name,
+        )
+        logical_not_op.AddInputTensors([input_tensor_wrapper])
+        logical_not_op.AddOutputTensors([output_tensor_wrapper])
+
+        return logical_not_op
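
The visitor's target list matches the op name that export produces. A minimal sketch confirming what the captured graph contains (module and input are illustrative):

    import torch

    class Not(torch.nn.Module):
        def forward(self, x):
            return torch.logical_not(x > 0)

    ep = torch.export.export(Not(), (torch.rand(1, 2, 3, 4),))
    # The graph contains a call to torch.ops.aten.logical_not.default,
    # the target string this visitor registers for.
    print(ep.graph)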
backends/qualcomm/builders/op_where.py (new file)

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseSelect, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Where(NodeVisitor):
+    target = ["aten.where.self"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        conditional_input_node = node.args[0]
+        conditional_input_tensor = self.get_tensor(conditional_input_node, node)
+        conditional_input_tensor_wrapper = self.define_tensor(
+            conditional_input_node,
+            node,
+            conditional_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        true_input_node = node.args[1]
+        true_input_tensor = self.get_tensor(true_input_node, node)
+        true_input_tensor_wrapper = self.define_tensor(
+            true_input_node,
+            node,
+            true_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        false_input_node = node.args[2]
+        false_input_tensor = self.get_tensor(false_input_node, node)
+        false_input_tensor_wrapper = self.define_tensor(
+            false_input_node,
+            node,
+            false_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        where_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseSelect.op_name,
+        )
+        where_op.AddInputTensors(
+            [
+                conditional_input_tensor_wrapper,
+                true_input_tensor_wrapper,
+                false_input_tensor_wrapper,
+            ]
+        )
+        where_op.AddOutputTensors([output_tensor_wrapper])
+
+        return where_op
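
The builder forwards node.args in order, so QNN's ElementWiseSelect receives (condition, true, false) exactly as torch.where does. A small eager example of that contract:

    import torch

    cond = torch.tensor([True, False])
    on_true = torch.tensor([1.0, 1.0])
    on_false = torch.tensor([0.0, 0.0])
    # ElementWiseSelect input order mirrors torch.where(cond, on_true, on_false)
    assert torch.equal(torch.where(cond, on_true, on_false), torch.tensor([1.0, 0.0]))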

backends/qualcomm/builders/qnn_constants.py

Lines changed: 10 additions & 0 deletions

@@ -158,6 +158,11 @@ class OpElementWiseNeuron:
     param_beta: str = "beta"


+@dataclass(init=False, frozen=True)
+class OpElementWiseNot:
+    op_name: str = "ElementWiseNot"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWisePower:
     op_name: str = "ElementWisePower"
@@ -173,6 +178,11 @@ class OpElementWiseSin:
     op_name: str = "ElementWiseSin"


+@dataclass(init=False, frozen=True)
+class OpElementWiseSelect:
+    op_name = "ElementWiseSelect"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseSubtract:
     op_name = "ElementWiseSubtract"
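
These frozen, init=False dataclasses act purely as namespaces for QNN op-name strings; they are never instantiated. A quick illustration of the access pattern the builders rely on:

    from dataclasses import dataclass

    @dataclass(init=False, frozen=True)
    class OpElementWiseSelect:
        op_name = "ElementWiseSelect"

    # Class-level attribute access, no instance needed:
    assert OpElementWiseSelect.op_name == "ElementWiseSelect"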

backends/qualcomm/partition/common_defs.py

Lines changed: 0 additions & 3 deletions

@@ -7,7 +7,6 @@

 from executorch.exir.dialects._ops import ops as exir_ops

-
 not_supported_operator = [
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
@@ -18,8 +17,6 @@

 to_be_implemented_operator = [
     exir_ops.edge.aten.any.dim,
-    exir_ops.edge.aten.logical_not.default,
-    exir_ops.edge.aten.where.self,
 ]

 constant_operator = [

backends/qualcomm/quantizer/annotators.py

Lines changed: 23 additions & 0 deletions

@@ -1070,3 +1070,26 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
         output_qspec=quantization_config.output_activation,
         _annotated=True,
     )
+
+
+@register_annotator([torch.ops.aten.where.self])
+def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
+    true_input_act = node.args[1]
+    false_input_act = node.args[2]
+    if _is_annotated([node]):
+        return
+
+    _annotate_input_qspec_map(
+        node,
+        true_input_act,
+        quantization_config.input_activation,
+    )
+
+    _annotate_input_qspec_map(
+        node,
+        false_input_act,
+        quantization_config.input_activation,
+    )
+
+    _annotate_output_qspec(node, quantization_config.output_activation)
+    _mark_nodes_as_annotated([node])
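
Note that only the value operands (args[1] and args[2]) receive input quant specs; the boolean condition in args[0] is left unquantized. A hedged sketch of how this annotator participates in the standard PT2E flow (the QnnQuantizer import path is assumed from this repository's layout):

    import torch
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer  # assumed path
    from executorch.backends.qualcomm.tests.models import Where

    module = Where()
    sample_input = (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2))

    captured = torch.export.export_for_training(module, sample_input).module()
    prepared = prepare_pt2e(captured, QnnQuantizer())  # annotate_where fires here
    prepared(*sample_input)                            # calibration
    quantized = convert_pt2e(prepared)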

backends/qualcomm/tests/models.py

Lines changed: 26 additions & 0 deletions

@@ -793,6 +793,14 @@ def forward(self, x):
         return torch.log(x)


+class LogicalNot(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.logical_not(x > 0)
+
+
 class LogSoftmax(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1306,3 +1314,21 @@ def forward(self, x, y):
         x = x.view(new_shape)
         x = x.permute(0, 2, 1, 3)
         return torch.matmul(x, y.transpose(-1, -2))
+
+
+class Where(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y, z):
+        return torch.where(x >= torch.zeros(x.shape), y, z)
+
+
+class WhereConstant(torch.nn.Module):
+    def __init__(self, pos, neg):
+        super().__init__()
+        self.register_buffer("pos", pos)
+        self.register_buffer("neg", neg)
+
+    def forward(self, x):
+        return torch.where(x >= torch.zeros(x.shape), self.pos, self.neg)
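
A quick eager-mode sanity check of the new test modules (assumes the classes above are in scope):

    import torch

    m = Where()
    x, y, z = torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)
    # x >= torch.zeros(x.shape) is an elementwise x >= 0 comparison
    assert torch.equal(m(x, y, z), torch.where(x >= 0, y, z))

    mc = WhereConstant(torch.randn(3, 2), torch.randn(3, 2))
    out = mc(torch.randn(3, 2))  # selects from the registered buffers elementwise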

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 36 additions & 0 deletions

@@ -517,6 +517,11 @@ def test_qnn_backend_log(self):
         sample_input = (torch.rand([1, 2, 3, 4]),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_logical_not(self):
+        module = LogicalNot()  # noqa: F405
+        sample_input = (torch.rand([1, 2, 3, 4]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
@@ -696,6 +701,18 @@ def test_qnn_backend_view(self):
         sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_where(self):
+        modules = [
+            Where(),  # noqa: F405
+            WhereConstant(torch.randn(3, 2), torch.randn(3, 2)),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
+            (torch.randn(3, 2),),
+        ]
+        for i, module in enumerate(modules):
+            self.lower_module_and_test_output(module, sample_inputs[i])
+

 class TestQNNFloatingPointModel(TestQNN):
     # TODO: refactor to support different backends
@@ -1400,6 +1417,12 @@ def test_qnn_backend_log(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_logical_not(self):
+        module = LogicalNot()  # noqa: F405
+        sample_input = (torch.rand([1, 2, 3, 4]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
@@ -1613,6 +1636,19 @@ def test_qnn_backend_view(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_where(self):
+        modules = [
+            Where(),  # noqa: F405
+            WhereConstant(torch.randn(3, 2), torch.randn(3, 2)),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
+            (torch.randn(3, 2),),
+        ]
+        for i, module in enumerate(modules):
+            module = self.get_qdq_module(module, sample_inputs[i])
+            self.lower_module_and_test_output(module, sample_inputs[i])
+

 class TestQNNQuantizedModel(TestQNN):
     # TODO: refactor to support different backends
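
The new tests can be selected individually; a hypothetical direct run (the dotted class name is inferred from the surrounding diff, and the real suite also expects device arguments such as serial number, SoC model, and build folder, parsed in test_qnn_delegate's __main__):

    import unittest

    suite = unittest.TestLoader().loadTestsFromName(
        "backends.qualcomm.tests.test_qnn_delegate."
        "TestQNNQuantizedOperator.test_qnn_backend_where"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)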
