Commit e268aa5

Qualcomm AI Engine Direct - op support
- where / logical_not
- test cases

1 parent b1d76c9 commit e268aa5

9 files changed, +237 -3 lines changed

backends/qualcomm/_passes/layout_transform.py

+2
@@ -68,6 +68,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.log.default,
+        exir_ops.edge.aten.logical_not.default,
         exir_ops.edge.aten.lt.Scalar,
         exir_ops.edge.aten.lt.Tensor,
         exir_ops.edge.aten._log_softmax.default,
@@ -88,6 +89,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.sum.dim_IntList,
         exir_ops.edge.aten.topk.default,
         exir_ops.edge.aten._to_copy.default,
+        exir_ops.edge.aten.where.self,
         *q_ops,
         *dq_ops,
         _operator.getitem,

backends/qualcomm/builders/__init__.py

+4
@@ -40,6 +40,7 @@
     op_linear,
     op_log,
     op_log_softmax,
+    op_logical_not,
     op_lt,
     op_matmul,
     op_max,
@@ -76,6 +77,7 @@
     op_unsqueeze,
     op_upsample_bilinear2d,
     op_upsample_nearest2d,
+    op_where,
 )
 
 __all__ = [
@@ -113,6 +115,7 @@
     op_le,
     op_linear,
     op_log,
+    op_logical_not,
     op_log_softmax,
     op_lt,
     op_matmul,
@@ -150,4 +153,5 @@
     op_unsqueeze,
     op_upsample_bilinear2d,
     op_upsample_nearest2d,
+    op_where,
 ]
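Importing op_logical_not and op_where here is what executes their @register_node_visitor decorators at package-import time. A hypothetical simplification of that registry pattern (names are illustrative, not the real node_visitor internals):

# Hypothetical registry sketch: map each declared aten target to its visitor.
_VISITOR_REGISTRY = {}

def register_node_visitor(cls):
    for target in cls.target:
        _VISITOR_REGISTRY[target] = cls
    return cls

@register_node_visitor
class Not:
    target = ["aten.logical_not.default"]

print(_VISITOR_REGISTRY)  # {'aten.logical_not.default': <class 'Not'>}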
backends/qualcomm/builders/op_logical_not.py (new file)

+55

@@ -0,0 +1,55 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseNot, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Not(NodeVisitor):
+    target = ["aten.logical_not.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        logical_not_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseNot.op_name,
+        )
+        logical_not_op.AddInputTensors([input_tensor_wrapper])
+        logical_not_op.AddOutputTensors([output_tensor_wrapper])
+
+        return logical_not_op
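The visitor is a one-to-one mapping from aten.logical_not.default onto QNN's ElementWiseNot: a single input wrapper, a single output wrapper, and no op parameters. For reference, the eager semantics the lowered op must reproduce (note the test model below feeds it a bool tensor via x > 0):

import torch

x = torch.tensor([0.0, 1.0, -2.0])
cond = x > 0                    # tensor([False,  True, False])
print(torch.logical_not(cond))  # tensor([ True, False,  True])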
backends/qualcomm/builders/op_where.py (new file)

+81

@@ -0,0 +1,81 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpElementWiseSelect, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Where(NodeVisitor):
+    target = ["aten.where.self"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        conditional_input_node = node.args[0]
+        conditional_input_tensor = self.get_tensor(conditional_input_node, node)
+        conditional_input_tensor_wrapper = self.define_tensor(
+            conditional_input_node,
+            node,
+            conditional_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        true_input_node = node.args[1]
+        true_input_tensor = self.get_tensor(true_input_node, node)
+        true_input_tensor_wrapper = self.define_tensor(
+            true_input_node,
+            node,
+            true_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        false_input_node = node.args[2]
+        false_input_tensor = self.get_tensor(false_input_node, node)
+        false_input_tensor_wrapper = self.define_tensor(
+            false_input_node,
+            node,
+            false_input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        where_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpElementWiseSelect.op_name,
+        )
+        where_op.AddInputTensors(
+            [
+                conditional_input_tensor_wrapper,
+                true_input_tensor_wrapper,
+                false_input_tensor_wrapper,
+            ]
+        )
+        where_op.AddOutputTensors([output_tensor_wrapper])
+
+        return where_op
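aten.where.self lowers to QNN's ElementWiseSelect with the three inputs passed in the same order torch.where takes them: condition, true branch, false branch. The reference semantics:

import torch

cond = torch.tensor([[True, False], [False, True]])
y = torch.ones(2, 2)   # chosen where cond is True
z = torch.zeros(2, 2)  # chosen where cond is False
print(torch.where(cond, y, z))
# tensor([[1., 0.],
#         [0., 1.]])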

backends/qualcomm/builders/qnn_constants.py

+10
@@ -158,6 +158,11 @@ class OpElementWiseNeuron:
     param_beta: str = "beta"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseNot:
+    op_name: str = "ElementWiseNot"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWisePower:
     op_name: str = "ElementWisePower"
@@ -173,6 +178,11 @@ class OpElementWiseSin:
     op_name: str = "ElementWiseSin"
 
 
+@dataclass(init=False, frozen=True)
+class OpElementWiseSelect:
+    op_name = "ElementWiseSelect"
+
+
 @dataclass(init=False, frozen=True)
 class OpElementWiseSubtract:
     op_name = "ElementWiseSubtract"

backends/qualcomm/partition/common_defs.py

-3
@@ -7,7 +7,6 @@
 
 from executorch.exir.dialects._ops import ops as exir_ops
 
-
 not_supported_operator = [
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
@@ -18,8 +17,6 @@
 
 to_be_implemented_operator = [
     exir_ops.edge.aten.any.dim,
-    exir_ops.edge.aten.logical_not.default,
-    exir_ops.edge.aten.where.self,
 ]
 
 constant_operator = [

backends/qualcomm/quantizer/annotators.py

+23
@@ -1070,3 +1070,26 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
             output_qspec=quantization_config.output_activation,
             _annotated=True,
         )
+
+
+@register_annotator([torch.ops.aten.where.self])
+def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
+    true_input_act = node.args[1]
+    false_input_act = node.args[2]
+    if _is_annotated([node]):
+        return
+
+    _annotate_input_qspec_map(
+        node,
+        true_input_act,
+        quantization_config.input_activation,
+    )
+
+    _annotate_input_qspec_map(
+        node,
+        false_input_act,
+        quantization_config.input_activation,
+    )
+
+    _annotate_output_qspec(node, quantization_config.output_activation)
+    _mark_nodes_as_annotated([node])
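Only args[1] (true branch) and args[2] (false branch) receive input qspecs; args[0] is the boolean condition, which carries no quantizable values. A small torch.fx sketch (illustrative, not the quantizer's own tracing) showing the argument order the annotator relies on:

import torch

def f(cond, y, z):
    return torch.where(cond, y, z)

gm = torch.fx.symbolic_trace(f)
where_node = next(n for n in gm.graph.nodes if n.op == "call_function")
print(where_node.args)  # (cond, y, z): condition, true input, false input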

backends/qualcomm/tests/models.py

+26
@@ -793,6 +793,14 @@ def forward(self, x):
         return torch.log(x)
 
 
+class LogicalNot(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.logical_not(x > 0)
+
+
 class LogSoftmax(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1306,3 +1314,21 @@ def forward(self, x, y):
         x = x.view(new_shape)
         x = x.permute(0, 2, 1, 3)
         return torch.matmul(x, y.transpose(-1, -2))
+
+
+class Where(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y, z):
+        return torch.where(x >= torch.zeros(x.shape), y, z)
+
+
+class WhereConstant(torch.nn.Module):
+    def __init__(self, pos, neg):
+        super().__init__()
+        self.register_buffer("pos", pos)
+        self.register_buffer("neg", neg)
+
+    def forward(self, x):
+        return torch.where(x >= torch.zeros(x.shape), self.pos, self.neg)
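Assuming the classes above are in scope, a quick eager sanity check of what the lowered graphs must reproduce (x >= torch.zeros(x.shape) is equivalent to the elementwise x >= 0):

import torch

x, y, z = torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)

m = Where()  # defined above
assert torch.equal(m(x, y, z), torch.where(x >= 0, y, z))

m = LogicalNot()  # defined above
assert torch.equal(m(x), ~(x > 0))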

backends/qualcomm/tests/test_qnn_delegate.py

+36
@@ -513,6 +513,11 @@ def test_qnn_backend_log(self):
         sample_input = (torch.rand([1, 2, 3, 4]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_logical_not(self):
+        module = LogicalNot()  # noqa: F405
+        sample_input = (torch.rand([1, 2, 3, 4]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
@@ -692,6 +697,18 @@ def test_qnn_backend_view(self):
         sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256]))
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_where(self):
+        modules = [
+            Where(),  # noqa: F405
+            WhereConstant(torch.randn(3, 2), torch.randn(3, 2)),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
+            (torch.randn(3, 2),),
+        ]
+        for i, module in enumerate(modules):
+            self.lower_module_and_test_output(module, sample_inputs[i])
+
 
 class TestQNNFloatingPointModel(TestQNN):
     # TODO: refactor to support different backends
@@ -1396,6 +1413,12 @@ def test_qnn_backend_log(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_logical_not(self):
+        module = LogicalNot()  # noqa: F405
+        sample_input = (torch.rand([1, 2, 3, 4]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_log_softmax(self):
         module = LogSoftmax()  # noqa: F405
         sample_input = (torch.randn([1, 4, 8, 8]),)
@@ -1609,6 +1632,19 @@ def test_qnn_backend_view(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_where(self):
+        modules = [
+            Where(),  # noqa: F405
+            WhereConstant(torch.randn(3, 2), torch.randn(3, 2)),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(3, 2), torch.randn(3, 2), torch.randn(3, 2)),
+            (torch.randn(3, 2),),
+        ]
+        for i, module in enumerate(modules):
+            module = self.get_qdq_module(module, sample_inputs[i])
+            self.lower_module_and_test_output(module, sample_inputs[i])
+
 
 class TestQNNQuantizedModel(TestQNN):
     # TODO: refactor to support different backends