Skip to content

Commit 42412d7

Browse files
Qualcomm AI Engine Direct - oss model enablement (EfficientSAM)
- e2e script for https://github.com/yformer/EfficientSAM - Fastvit breakage fix - Add support for cum_sum - Add bicubic interpolate transform pass - Fix stack op
1 parent c9c5481 commit 42412d7

25 files changed

+788
-14
lines changed

backends/qualcomm/_passes/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .annotate_unbind import AnnotateUnbind
1010
from .convert_bmm_to_matmul import ConvertBmmToMatmul
1111
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
12+
from .convert_upsample_bicubic2d import ConvertUpsampleBicubicWithBilinear
1213
from .decompose_any import DecomposeAny
1314
from .decompose_einsum import DecomposeEinsum
1415
from .decompose_expm1 import DecomposeExpM1
@@ -40,6 +41,7 @@
4041
ConvertBmmToMatmul,
4142
ConvertConv1dToConv2d,
4243
DecomposeAny,
44+
ConvertUpsampleBicubicWithBilinear,
4345
DecomposeEinsum,
4446
DecomposeExpM1,
4547
DecomposeLinalgVectorNorm,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from executorch.exir.dialects._ops import ops as exir_ops
7+
from executorch.exir.pass_base import ExportPass
8+
9+
10+
class ConvertUpsampleBicubicWithBilinear(ExportPass):
11+
"""
12+
Qnn does not support bicubic interpolation, so we need to convert it to bilinear.
13+
This pass will convert bicubic interpolation to bilinear interpolation.
14+
"""
15+
16+
bicubic_op_targets = {
17+
exir_ops.edge.aten.upsample_bicubic2d.vec,
18+
}
19+
upsample_bilinear_op = exir_ops.edge.aten.upsample_bilinear2d.default
20+
21+
def __init__(self):
22+
super(ConvertUpsampleBicubicWithBilinear, self).__init__()
23+
24+
def call_operator(self, op, args, kwargs, meta):
25+
if op not in self.bicubic_op_targets:
26+
return super().call_operator(op, args, kwargs, meta)
27+
return super().call_operator(self.upsample_bilinear_op, args[:-1], kwargs, meta)

backends/qualcomm/_passes/layout_transform.py

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class LayoutTransform(ExportPass):
5555
exir_ops.edge.aten.ceil.default,
5656
exir_ops.edge.aten.clamp.default,
5757
exir_ops.edge.aten.constant_pad_nd.default,
58+
exir_ops.edge.aten.cumsum.default,
5859
exir_ops.edge.aten.div.Tensor,
5960
exir_ops.edge.aten.elu.default,
6061
exir_ops.edge.aten.eq.Tensor,

backends/qualcomm/_passes/qnn_pass_manager.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
AnnotateUnbind,
1515
ConvertBmmToMatmul,
1616
ConvertConv1dToConv2d,
17+
ConvertUpsampleBicubicWithBilinear,
1718
DecomposeAny,
1819
DecomposeEinsum,
1920
DecomposeExpM1,
@@ -74,6 +75,7 @@ def get_capture_program_passes():
7475
(AnnotateUnbind, True),
7576
(ConvertBmmToMatmul, True),
7677
(ConvertConv1dToConv2d, True),
78+
(ConvertUpsampleBicubicWithBilinear, False),
7779
(DecomposeAny, True),
7880
(ExpandBroadcastTensorShape, False),
7981
(FixedLinearKeepDim, True),

backends/qualcomm/_passes/recompose_pixel_unshuffle.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,11 @@ def call(self, graph_module: torch.fx.GraphModule):
4545
continue
4646

4747
view_node = premute_node.args[0]
48-
if any(
49-
[
50-
view_node.op != "call_function",
51-
view_node.target != self.view_target,
52-
len(view_node.args[1]) != 6,
53-
len(premute_node.args[1]) != 6,
54-
]
48+
if (
49+
view_node.op != "call_function"
50+
or view_node.target != self.view_target
51+
or len(view_node.args[1]) != 6
52+
or len(premute_node.args[1]) != 6
5553
):
5654
continue
5755

backends/qualcomm/_passes/utils.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def get_passes_dependency_for_capture_program():
7878
AnnotateUnbind,
7979
ConvertBmmToMatmul,
8080
ConvertConv1dToConv2d,
81+
ConvertUpsampleBicubicWithBilinear,
8182
DecomposeAny,
8283
DecomposeLinalgVectorNorm,
8384
ExpandBroadcastTensorShape,
@@ -96,18 +97,20 @@ def get_passes_dependency_for_capture_program():
9697
AnnotateQuantAttrs: [
9798
RecomposePixelUnshuffle,
9899
ConvertBmmToMatmul,
100+
ConvertUpsampleBicubicWithBilinear,
99101
RemoveRedundancy,
100102
],
101103
AnnotateStack: [RemoveRedundancy],
102104
AnnotateUnbind: [RemoveRedundancy],
103105
ConvertBmmToMatmul: [RecomposePixelUnshuffle],
104106
ConvertConv1dToConv2d: [FoldQDQ],
107+
ConvertUpsampleBicubicWithBilinear: [RemoveRedundancy],
105108
DecomposeAny: [RemoveRedundancy],
106109
DecomposeLinalgVectorNorm: [RemoveRedundancy],
107110
ExpandBroadcastTensorShape: [RemoveRedundancy],
108111
FixedLinearKeepDim: [FoldQDQ],
109112
FoldQDQ: [AnnotateQuantAttrs, AnnotateStack, AnnotateUnbind],
110-
I64toI32: [RemoveRedundancy],
113+
I64toI32: [ConvertUpsampleBicubicWithBilinear, RemoveRedundancy],
111114
LayoutTransform: [
112115
AnnotateQuantAttrs,
113116
ConvertConv1dToConv2d,

backends/qualcomm/builders/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
op_clamp,
2222
op_conv2d,
2323
op_cos,
24+
op_cum_sum,
2425
op_depth_to_space,
2526
op_dequantize,
2627
op_div,
@@ -108,6 +109,7 @@
108109
op_clamp,
109110
op_conv2d,
110111
op_cos,
112+
op_cum_sum,
111113
op_depth_to_space,
112114
op_dequantize,
113115
op_div,

backends/qualcomm/builders/op_cos.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6-
76
from typing import Dict
87

98
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import cast, Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
9+
10+
import numpy as np
11+
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA
13+
14+
from .node_visitor import NodeVisitor, register_node_visitor
15+
from .qnn_constants import OpCumulativeSum, QNN_OP_PACKAGE_NAME_QTI_AISW
16+
17+
18+
@register_node_visitor
19+
class CumulativeSum(NodeVisitor):
20+
target = ["aten.cumsum.default"]
21+
22+
def __init__(self, *args) -> None:
23+
super().__init__(*args)
24+
25+
def get_param(self, node, input_tensor):
26+
dim = node.args[1]
27+
28+
if dim < 0:
29+
dim = dim % len(input_tensor.shape)
30+
if QCOM_AXIS_ORDER in node.meta:
31+
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
32+
33+
return cast(np.uint32, dim)
34+
35+
def define_node(
36+
self,
37+
node: torch.fx.Node,
38+
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
39+
) -> PyQnnWrapper.PyQnnOpWrapper:
40+
input_node = node.args[0]
41+
input_tensor = self.get_tensor(input_node, node)
42+
input_tensor_wrapper = self.define_tensor(
43+
input_node,
44+
node,
45+
input_tensor,
46+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
47+
nodes_to_wrappers,
48+
)
49+
50+
dim = self.get_param(node, input_tensor)
51+
52+
output_tensor = self.get_tensor(node, node)
53+
output_tensor_wrapper = self.define_tensor(
54+
node,
55+
node,
56+
output_tensor,
57+
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
58+
nodes_to_wrappers,
59+
)
60+
61+
cumsum_op = PyQnnWrapper.PyQnnOpWrapper(
62+
node.name,
63+
QNN_OP_PACKAGE_NAME_QTI_AISW,
64+
OpCumulativeSum.op_name,
65+
)
66+
cumsum_op.AddInputTensors([input_tensor_wrapper])
67+
cumsum_op.AddOutputTensors([output_tensor_wrapper])
68+
cumsum_op.AddScalarParam(
69+
OpCumulativeSum.param_axis,
70+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
71+
{QCOM_DATA: dim},
72+
)
73+
cumsum_op.AddScalarParam(
74+
OpCumulativeSum.param_exclusive,
75+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
76+
{QCOM_DATA: False},
77+
)
78+
cumsum_op.AddScalarParam(
79+
OpCumulativeSum.param_reverse,
80+
PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
81+
{QCOM_DATA: False},
82+
)
83+
84+
return cumsum_op

backends/qualcomm/builders/op_sin.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6-
76
from typing import Dict
87

98
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

backends/qualcomm/builders/op_stack.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_node(
5151

5252
dim = 0 if len(node.args) == 1 else cast(int, node.args[1])
5353
if dim < 0:
54-
dim = dim % len(input_tensor.shape)
54+
dim = dim % len(output_tensor.shape)
5555
if QCOM_AXIS_ORDER in node.meta:
5656
dim = node.meta[QCOM_AXIS_ORDER].index(dim)
5757
stack_op = PyQnnWrapper.PyQnnOpWrapper(

backends/qualcomm/builders/qnn_constants.py

+8
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ class OpConvert:
5757
op_name: str = "Convert"
5858

5959

60+
@dataclass(init=False, frozen=True)
61+
class OpCumulativeSum:
62+
op_name = "CumulativeSum"
63+
param_axis = "axis"
64+
param_exclusive = "exclusive"
65+
param_reverse = "reverse"
66+
67+
6068
@dataclass(init=False, frozen=True)
6169
class OpDepthToSpace:
6270
op_name: str = "DepthToSpace"

backends/qualcomm/partition/common_defs.py

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
exir_ops.edge.aten.clone.default,
1414
exir_ops.edge.aten.slice_scatter.default,
1515
exir_ops.edge.aten.copy.default,
16+
exir_ops.edge.aten.upsample_bicubic2d.vec,
1617
exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
1718
]
1819

backends/qualcomm/partition/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
3939
torch.ops.aten.rms_norm.default,
4040
torch.ops.aten._safe_softmax.default,
4141
torch.ops.aten.stack.default,
42+
torch.ops.aten.upsample_bicubic2d.vec,
4243
# This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
4344
# torch.ops.aten.unbind.int,
4445
torch.ops.pt2e_quant.quantize_affine.default,

backends/qualcomm/quantizer/annotators.py

+5
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,11 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
976976
)
977977

978978

979+
@register_annotator([torch.ops.aten.cumsum.default])
980+
def annotate_cumsum(node: Node, quantization_config: QuantizationConfig) -> None:
981+
annotate_single_in_single_out(node, quantization_config)
982+
983+
979984
@register_annotator([torch.ops.aten.linear.default])
980985
def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
981986
act_node = node.args[0]

backends/qualcomm/tests/models.py

+8
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,14 @@ def forward(self, x):
568568
return torch.cos(x)
569569

570570

571+
class CumSum(torch.nn.Module):
572+
def __init__(self):
573+
super().__init__()
574+
575+
def forward(self, x):
576+
return x.cumsum(dim=0)
577+
578+
571579
class Div(torch.nn.Module):
572580
def __init__(self):
573581
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,11 @@ def test_qnn_backend_cos(self):
233233
sample_input = (torch.randn(2, 5, 1, 3),)
234234
self.lower_module_and_test_output(module, sample_input)
235235

236+
def test_qnn_backend_cumsum(self):
237+
module = CumSum() # noqa: F405
238+
sample_input = (torch.randn(4),)
239+
self.lower_module_and_test_output(module, sample_input)
240+
236241
def test_qnn_backend_einsum_outer_product(self):
237242
module = EinsumOuterProduct() # noqa: F405
238243
x = torch.randn(5)
@@ -1297,6 +1302,12 @@ def test_qnn_backend_cos(self):
12971302
module = self.get_qdq_module(module, sample_input)
12981303
self.lower_module_and_test_output(module, sample_input)
12991304

1305+
def test_qnn_backend_cumsum(self):
1306+
module = CumSum() # noqa: F405
1307+
sample_input = (torch.randn(4),)
1308+
module = self.get_qdq_module(module, sample_input)
1309+
self.lower_module_and_test_output(module, sample_input)
1310+
13001311
def test_qnn_backend_einsum_outer_product(self):
13011312
module = EinsumOuterProduct() # noqa: F405
13021313
x = torch.randn(5)
@@ -3537,7 +3548,6 @@ def test_conv_former(self):
35373548
self.assertGreaterEqual(msg["top_1"], 60)
35383549
self.assertGreaterEqual(msg["top_5"], 80)
35393550

3540-
@unittest.skip("bicubic resize is not supported")
35413551
def test_dino_v2(self):
35423552
if not self.required_envs([self.image_dataset]):
35433553
self.skipTest("missing required envs")
@@ -3573,6 +3583,46 @@ def test_dino_v2(self):
35733583
self.assertGreaterEqual(msg["top_1"], 70)
35743584
self.assertGreaterEqual(msg["top_5"], 85)
35753585

3586+
def test_efficientSAM(self):
3587+
if not self.required_envs(
3588+
[self.image_dataset, self.pretrained_weight, self.oss_repo]
3589+
):
3590+
self.skipTest("missing required envs")
3591+
cmds = [
3592+
"python",
3593+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientSAM.py",
3594+
"--dataset",
3595+
self.image_dataset,
3596+
"--artifact",
3597+
self.artifact_dir,
3598+
"--build_folder",
3599+
self.build_folder,
3600+
"--device",
3601+
self.device,
3602+
"--model",
3603+
self.model,
3604+
"--oss_repo",
3605+
self.oss_repo,
3606+
"--pretrained_weight",
3607+
self.pretrained_weight,
3608+
"--ip",
3609+
self.ip,
3610+
"--port",
3611+
str(self.port),
3612+
]
3613+
if self.host:
3614+
cmds.extend(["--host", self.host])
3615+
3616+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
3617+
with Listener((self.ip, self.port)) as listener:
3618+
conn = listener.accept()
3619+
p.communicate()
3620+
msg = json.loads(conn.recv())
3621+
if "Error" in msg:
3622+
self.fail(msg["Error"])
3623+
else:
3624+
self.assertGreaterEqual(msg["MIoU"], 0.55)
3625+
35763626
def test_esrgan(self):
35773627
if not self.required_envs():
35783628
self.skipTest("missing required envs")

backends/qualcomm/tests/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,14 @@ def lower_module_and_test_output(
438438
skip_node_id_set: set = None,
439439
skip_node_op_set: set = None,
440440
dynamic_shapes: Dict = None,
441+
passes_job: collections.OrderedDict = None,
441442
):
442443
delegated_program = to_edge_transform_and_lower_to_qnn(
443444
module,
444445
sample_inputs,
445446
self.compiler_specs,
446447
dynamic_shapes=dynamic_shapes,
448+
passes_job=passes_job,
447449
skip_node_id_set=skip_node_id_set,
448450
skip_node_op_set=skip_node_op_set,
449451
)

0 commit comments

Comments
 (0)