Skip to content

Commit 13da2d5

Browse files
author
Nathanael See
committed
[ET-VK][int4] Wrap int4 linear calls with view_copy nodes to squeeze/unsqueeze inputs
Pull Request resolved: #8226 This is done automatically for full-precision linear/mm nodes in the graph at torch.export graph tracing time, but is not done for the int4 op. The new pass adds view_copy nodes, as there are subsequent passes which can fuse view_copy nodes if redundant, and convert view_copy nodes to squeeze/unsqueeze nodes. ghstack-source-id: 264952606 @exported-using-ghexport Differential Revision: [D69065866](https://our.internmc.facebook.com/intern/diff/D69065866/)
1 parent 65117d5 commit 13da2d5

File tree

5 files changed

+100
-1
lines changed

5 files changed

+100
-1
lines changed

backends/vulkan/_passes/TARGETS

+17-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,21 @@ runtime.python_library(
3030
]
3131
)
3232

33+
runtime.python_library(
34+
name = "squeeze_int4_linear_inputs",
35+
srcs = [
36+
"squeeze_int4_linear_inputs.py",
37+
],
38+
visibility = [
39+
"//executorch/backends/...",
40+
],
41+
deps = [
42+
"//executorch/backends/vulkan:custom_ops_lib",
43+
"//executorch/exir:pass_base",
44+
"//executorch/exir/dialects:lib",
45+
]
46+
)
47+
3348
runtime.python_library(
3449
name = "remove_asserts",
3550
srcs = ["remove_asserts.py"],
@@ -99,6 +114,7 @@ runtime.python_library(
99114
":remove_asserts",
100115
":remove_local_scalar_dense",
101116
":remove_redundant_ops",
102-
":tag_memory_meta_pass"
117+
":squeeze_int4_linear_inputs",
118+
":tag_memory_meta_pass",
103119
]
104120
)

backends/vulkan/_passes/__init__.py

+12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
19
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
210
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
311
VkInt4WeightOnlyQuantizer,
@@ -12,6 +20,9 @@
1220
from executorch.backends.vulkan._passes.remove_redundant_ops import (
1321
RemoveRedundantOpsTransform,
1422
)
23+
from executorch.backends.vulkan._passes.squeeze_int4_linear_inputs import (
24+
SqueezeInt4LinearInputs,
25+
)
1526
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
1627

1728
__all__ = [
@@ -21,5 +32,6 @@
2132
"RemoveAssertsTransform",
2233
"RemoveLocalScalarDenseOpsTransform",
2334
"RemoveRedundantOpsTransform",
35+
"SqueezeInt4LinearInputs",
2436
"TagMemoryMetaPass",
2537
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Dict, List, Tuple
10+
11+
import executorch.backends.vulkan.custom_ops_lib # noqa: needed to access vk op
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
14+
15+
from torch.fx.node import Argument
16+
17+
18+
class SqueezeInt4LinearInputs(ExportPass):
19+
def call_operator(
20+
self,
21+
op, # pyre-ignore
22+
args: Tuple[Argument, ...],
23+
kwargs: Dict[str, Argument],
24+
meta: NodeMetadata,
25+
) -> ProxyValue:
26+
def _squeezable(shape: List[int]) -> bool:
27+
return len(shape) > 2 and 1 in shape
28+
29+
if op != exir_ops.edge.et_vk.linear_weight_int4.default:
30+
return super().call_operator(op, args, kwargs, meta)
31+
32+
# pyre-ignore[16]: `None` has no attribute `node`
33+
input_shape = args[0].node.meta["val"].shape
34+
output_shape = meta["val"].shape
35+
if not _squeezable(input_shape):
36+
return super().call_operator(op, args, kwargs, meta)
37+
38+
# squeeze input tensor
39+
squeeze_shape = list(input_shape)
40+
while _squeezable(squeeze_shape):
41+
squeeze_shape.remove(1)
42+
43+
squeeze_out = super().call_operator(
44+
exir_ops.edge.aten.view_copy.default,
45+
(args[0], squeeze_shape),
46+
kwargs,
47+
meta,
48+
)
49+
# call linear on squeezed output
50+
new_args = (squeeze_out, *args[1:])
51+
linear_out = super().call_operator(
52+
op,
53+
new_args,
54+
kwargs,
55+
meta,
56+
)
57+
# unsqueeze output
58+
unsqueeze_shape = list(output_shape)
59+
return super().call_operator(
60+
exir_ops.edge.aten.view_copy.default,
61+
(linear_out, unsqueeze_shape),
62+
kwargs,
63+
meta,
64+
)

backends/vulkan/targets.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def define_common_targets(is_fbcode = False):
328328
"//executorch/backends/transforms:fuse_dequant_linear",
329329
"//executorch/backends/transforms:fuse_view_copy",
330330
"//executorch/backends/transforms:remove_clone_ops",
331+
"//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze",
331332
"//executorch/backends/vulkan/_passes:vulkan_passes",
332333
"//executorch/backends/vulkan/serialization:lib",
333334
"//executorch/exir/backend:backend_details",

backends/vulkan/vulkan_preprocess.py

+6
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@
1919
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
2020
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
2121
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
22+
from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import (
23+
ViewCopyToSqueezeUnsqueezePass,
24+
)
2225
from executorch.backends.vulkan._passes import (
2326
insert_prepack_nodes,
2427
RemoveLocalScalarDenseOpsTransform,
2528
RemoveRedundantOpsTransform,
29+
SqueezeInt4LinearInputs,
2630
TagMemoryMetaPass,
2731
)
2832

@@ -149,7 +153,9 @@ def preprocess( # noqa: C901
149153
RemoveRedundantOpsTransform(),
150154
AddmmToLinearTransform(),
151155
FuseDequantLinearPass(),
156+
SqueezeInt4LinearInputs(),
152157
FuseViewCopyTransform(),
158+
ViewCopyToSqueezeUnsqueezePass(),
153159
FuseBatchNormWithConvPass(program),
154160
FuseClampPass(),
155161
],

0 commit comments

Comments
 (0)