
Commit 05fcdb6

Qualcomm AI Engine Direct - oss model enablement (EfficientSAM)

- e2e script for https://github.com/yformer/EfficientSAM
- FastViT breakage fix
- Pass-order correction
- Add support for cum_sum
1 parent ade2e3c commit 05fcdb6

File tree

18 files changed, +739 -13 lines changed


backends/qualcomm/_passes/layout_transform.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -54,6 +54,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.ceil.default,
         exir_ops.edge.aten.clamp.default,
         exir_ops.edge.aten.constant_pad_nd.default,
+        exir_ops.edge.aten.cumsum.default,
         exir_ops.edge.aten.div.Tensor,
         exir_ops.edge.aten.elu.default,
         exir_ops.edge.aten.eq.Tensor,
```
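Marking `aten.cumsum.default` as layout-agnostic lets this pass permute the tensor into QNN's preferred NHWC layout, as long as the axis parameter is remapped accordingly (which the new `op_cum_sum.py` below does via `QCOM_AXIS_ORDER`). A quick eager-mode sanity check of that property, using a hypothetical NCHW-to-NHWC permutation:

```python
import torch

x = torch.randn(2, 3, 4, 5)       # NCHW input
axis_order = (0, 2, 3, 1)         # NCHW -> NHWC permutation
dim = 1                           # channel axis in NCHW
remapped = axis_order.index(dim)  # channel axis lands at 3 in NHWC

# cumsum commutes with the permutation once the axis is remapped
a = torch.cumsum(x, dim=dim).permute(axis_order)
b = torch.cumsum(x.permute(axis_order), dim=remapped)
assert torch.equal(a, b)
```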

backends/qualcomm/_passes/recompose_pixel_unshuffle.py

Lines changed: 5 additions & 7 deletions
```diff
@@ -45,13 +45,11 @@ def call(self, graph_module: torch.fx.GraphModule):
                 continue

            view_node = premute_node.args[0]
-            if any(
-                [
-                    view_node.op != "call_function",
-                    view_node.target != self.view_target,
-                    len(view_node.args[1]) != 6,
-                    len(premute_node.args[1]) != 6,
-                ]
+            if (
+                view_node.op != "call_function"
+                or view_node.target != self.view_target
+                or len(view_node.args[1]) != 6
+                or len(premute_node.args[1]) != 6
             ):
                 continue

```
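Beyond readability, the rewrite changes evaluation order: `any([...])` materializes the whole list first, so `view_node.args[1]` gets indexed even when `view_node` is not a `call_function` node at all, while chained `or` short-circuits before the risky subscripts. A toy illustration with a hypothetical stand-in node (not the real fx API):

```python
class FakeNode:
    # hypothetical stand-in for an fx node with no args
    op = "placeholder"
    args = ()

view_node = FakeNode()

# Eager list: every element is built up front, so args[1] raises.
try:
    any([view_node.op != "call_function", len(view_node.args[1]) != 6])
except IndexError:
    print("any([...]) touched args[1] eagerly")

# Chained `or`: the first condition is True, args[1] is never evaluated.
if view_node.op != "call_function" or len(view_node.args[1]) != 6:
    print("safely skipped")
```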

backends/qualcomm/_passes/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -107,7 +107,7 @@ def get_passes_dependency_for_capture_program():
         ConvertToLinear: [RecomposePixelUnshuffle],
         DecomposeAny: [RemoveRedundancy],
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
-        ExpandBroadcastTensorShape: [RemoveRedundancy],
+        ExpandBroadcastTensorShape: [ConstantI64toI32, TensorI64toI32],
         FoldQDQ: [AnnotateQuantAttrs, AnnotateDecomposed],
         LayoutTransform: [
             AnnotateQuantAttrs,
```
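Here each pass maps to the passes that must run before it, so `ExpandBroadcastTensorShape` now waits for the i64-to-i32 conversions instead of only `RemoveRedundancy` (the commit's "pass-order correction"). A minimal sketch of how such a prerequisite dict resolves to a run order, assuming values are predecessors of their key; this is not the actual capture-program scheduler, and the prerequisite lists for the I64-to-I32 passes below are illustrative assumptions:

```python
from graphlib import TopologicalSorter

# Illustrative subset of the dependency dict above.
deps = {
    "ExpandBroadcastTensorShape": ["ConstantI64toI32", "TensorI64toI32"],
    "ConstantI64toI32": ["RemoveRedundancy"],
    "TensorI64toI32": ["RemoveRedundancy"],
}

# static_order() yields predecessors first, i.e. a valid pass schedule.
print(list(TopologicalSorter(deps).static_order()))
```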

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@
     op_clamp,
     op_conv2d,
     op_cos,
+    op_cum_sum,
     op_depth_to_space,
     op_dequantize,
     op_div,
@@ -106,6 +107,7 @@
     op_clamp,
     op_conv2d,
     op_cos,
+    op_cum_sum,
     op_depth_to_space,
     op_dequantize,
     op_div,
```

backends/qualcomm/builders/op_cos.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 from typing import Dict

 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
```
backends/qualcomm/builders/op_cum_sum.py

Lines changed: 84 additions & 0 deletions
```diff
@@ -0,0 +1,84 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpCumulativeSum, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class CumulativeSum(NodeVisitor):
+    target = ["aten.cumsum.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def get_param(self, node, input_tensor):
+        dim = node.args[1]
+
+        if dim < 0:
+            dim = dim % len(input_tensor.shape)
+        if QCOM_AXIS_ORDER in node.meta:
+            dim = node.meta[QCOM_AXIS_ORDER].index(dim)
+
+        return cast(np.uint32, dim)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        dim = self.get_param(node, input_tensor)
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        cumsum_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpCumulativeSum.op_name,
+        )
+        cumsum_op.AddInputTensors([input_tensor_wrapper])
+        cumsum_op.AddOutputTensors([output_tensor_wrapper])
+        cumsum_op.AddScalarParam(
+            OpCumulativeSum.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: dim},
+        )
+        cumsum_op.AddScalarParam(
+            OpCumulativeSum.param_exclusive,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+            {QCOM_DATA: False},
+        )
+        cumsum_op.AddScalarParam(
+            OpCumulativeSum.param_reverse,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+            {QCOM_DATA: False},
+        )
+
+        return cumsum_op
```
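The builder pins `exclusive` and `reverse` to `False`, so the op matches plain `torch.cumsum`. For reference, a sketch of what the two flags would compute if enabled, based on my reading of QNN's CumulativeSum semantics (not part of this commit):

```python
import torch

def cumulative_sum_ref(x, axis, exclusive=False, reverse=False):
    if reverse:
        x = x.flip(axis)
    if exclusive:
        # out[i] = sum(x[:i]): shift right by one, pad with a zero slice
        pad = torch.zeros_like(x.narrow(axis, 0, 1))
        x = torch.cat([pad, x.narrow(axis, 0, x.size(axis) - 1)], dim=axis)
    out = torch.cumsum(x, dim=axis)
    if reverse:
        out = out.flip(axis)
    return out

x = torch.randn(3, 5)
# With both flags off (as wired above) this is exactly torch.cumsum.
assert torch.equal(cumulative_sum_ref(x, axis=1), torch.cumsum(x, dim=1))
```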

backends/qualcomm/builders/op_sin.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 from typing import Dict

 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
```

backends/qualcomm/builders/qnn_constants.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -50,6 +50,14 @@ class OpConvert:
     op_name: str = "Convert"


+@dataclass(init=False, frozen=True)
+class OpCumulativeSum:
+    op_name: str = "CumulativeSum"
+    param_axis: str = "axis"
+    param_exclusive: str = "exclusive"
+    param_reverse: str = "reverse"
+
+
 @dataclass(init=False, frozen=True)
 class OpDepthToSpace:
     op_name: str = "DepthToSpace"
```

backends/qualcomm/quantizer/annotators.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -971,6 +971,11 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
     )


+@register_annotator([torch.ops.aten.cumsum.default])
+def annotate_cumsum(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.linear.default])
 def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
     act_node = node.args[0]
```
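The new annotator delegates to the shared helper that annotates cumsum's single input and output with the given quantization config. For orientation, a minimal sketch of the decorator-registry pattern this file uses; the real `register_annotator` in ExecuTorch differs in detail:

```python
from typing import Callable, Dict, List

_ANNOTATORS: Dict[object, Callable] = {}

def register_annotator(ops: List[object]) -> Callable:
    def decorator(fn: Callable) -> Callable:
        for op in ops:
            _ANNOTATORS[op] = fn  # looked up later by node target
        return fn
    return decorator

@register_annotator(["aten.cumsum.default"])
def annotate_cumsum(node, quantization_config) -> None:
    ...  # delegate to a shared single-in/single-out helper
```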

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -560,6 +560,14 @@ def forward(self, x):
         return torch.cos(x)


+class CumSum(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.cumsum(dim=0)
+
+
 class Div(torch.nn.Module):
     def __init__(self):
         super().__init__()
```
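An eager sanity check of the new test module; the import path assumes the usual executorch package layout, and the backend test harness compares lowered QNN output against this same eager reference:

```python
import torch
from executorch.backends.qualcomm.tests.models import CumSum

module = CumSum()
x = torch.randn(4, 8)
assert torch.equal(module(x), torch.cumsum(x, dim=0))
```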
