Skip to content

Commit b41c3a6

Browse files
authored
Merge branch 'main' into up-torchao
2 parents 153cdd4 + 2cce2db commit b41c3a6

15 files changed

+501
-264
lines changed

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,6 @@ def _vk_replace_linear_int4(
118118
# Use custom vulkan linear layer as default
119119
linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
120120
copy_weights: bool = False,
121-
# Serves the same purpose as `tensor_dim_limit` in
122-
# executorch.backends.vulkan.partitioner.VulkanSupportedOperators
123-
feature_limit: int = 16384,
124121
):
125122
for name, child in module.named_children():
126123
if isinstance(child, torch.nn.Linear) and (
@@ -131,8 +128,6 @@ def _vk_replace_linear_int4(
131128
if (
132129
_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
133130
or padding_allowed
134-
) and (
135-
child.out_features < feature_limit and child.in_features < feature_limit
136131
):
137132
new_linear = linear_class(
138133
child.in_features,
@@ -175,7 +170,6 @@ def __init__(
175170
inner_k_tiles: Optional[int] = 8,
176171
device: torch.device = torch.device("cpu"), # noqa
177172
precision: torch.dtype = torch.float32,
178-
feature_limit: int = 16384,
179173
) -> None:
180174
super().__init__()
181175
assert inner_k_tiles in [2, 4, 8]
@@ -186,9 +180,6 @@ def __init__(
186180
self.padding_allowed: bool = padding_allowed
187181
self.device: torch.device = device
188182
self.precision: torch.dtype = precision
189-
# Serves the same purpose as `tensor_dim_limit` in
190-
# executorch.backends.vulkan.partitioner.VulkanSupportedOperators
191-
self.feature_limit = feature_limit
192183

193184
@torch.no_grad()
194185
def _create_quantized_state_dict(
@@ -197,10 +188,7 @@ def _create_quantized_state_dict(
197188
cur_state_dict = model.state_dict()
198189
for fqn, mod in model.named_modules():
199190
# Add additional check to make sure features do not exceed feature limit
200-
if isinstance(mod, torch.nn.Linear) and (
201-
mod.out_features < self.feature_limit
202-
and mod.in_features < self.feature_limit
203-
):
191+
if isinstance(mod, torch.nn.Linear):
204192
out_features = mod.out_features
205193
in_features = mod.in_features
206194
logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")

backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,25 +27,38 @@ class SqueezeUnsqueezeInputs(ExportPass):
2727
exir_ops.edge.aten.gelu.default,
2828
}
2929

30+
def should_squeeze(self, op, shape: List[int]) -> bool: # pyre-ignore
31+
if len(shape) == 3:
32+
return shape[1] == 1 and shape[0] > 1
33+
if len(shape) == 4:
34+
# No need to squeeze if all dims are 1 except the width dim
35+
if all(dim == 1 for dim in shape[:-1]):
36+
return False
37+
# Otherwise, check for squeezable dim
38+
return 1 in shape[:-1]
39+
40+
# Prefer not to introduce additional orchestration ops by default
41+
return False
42+
3043
def call_operator(
3144
self,
3245
op, # pyre-ignore
3346
args: Tuple[Argument, ...],
3447
kwargs: Dict[str, Argument],
3548
meta: NodeMetadata,
3649
) -> ProxyValue:
37-
def _squeezable(shape: List[int]) -> bool:
38-
return len(shape) > 2 and 1 in shape
39-
4050
if op not in self._squeezable_ops:
4151
return super().call_operator(op, args, kwargs, meta)
42-
4352
# pyre-ignore[16]: `None` has no attribute `node`
4453
input_shape = args[0].node.meta["val"].shape
4554
output_shape = meta["val"].shape
46-
if not _squeezable(input_shape):
55+
56+
if not self.should_squeeze(op, input_shape):
4757
return super().call_operator(op, args, kwargs, meta)
4858

59+
def _squeezable(shape: List[int]) -> bool:
60+
return len(shape) > 2 and 1 in shape
61+
4962
# squeeze input tensor
5063
squeeze_shape = list(input_shape)
5164
while _squeezable(squeeze_shape):

backends/vulkan/op_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def register_int8_mm_op(features: OpFeatures):
393393

394394
@update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
395395
def register_int4_mm_op(features: OpFeatures):
396+
features.buffer_impl = True
396397
features.texture_impl = TextureImplFeatures(
397398
uses_axis_map=False,
398399
valid_packed_dims={PackedDim.WIDTH},
@@ -401,6 +402,7 @@ def register_int4_mm_op(features: OpFeatures):
401402
features.optimal_storage = VkStorageType.TEXTURE_3D
402403
features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
403404
features.handles_own_prepacking = True
405+
features.skip_limits_check = {1}
404406
return features
405407

406408

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -497,9 +497,7 @@ vTensor::vTensor(
497497
VK_CHECK_COND(
498498
dim_order_is_valid(dim_order_), "computed dim order is invalid");
499499

500-
if (storage_type != utils::kBuffer) {
501-
set_logical_limits(storage_.image_extents_);
502-
}
500+
set_logical_limits(storage_.image_extents_);
503501
}
504502

505503
// NOLINTNEXTLINE
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
${define_required_extensions("uint8")}
14+
${define_required_extensions("int8")}
15+
16+
layout(std430) buffer;
17+
18+
${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
19+
${layout_declare_tensor(B, "r", "nchw_4x2", "uint8", "buffer")}
20+
21+
layout(push_constant) uniform restrict Block {
22+
ivec4 qmat2_sizes;
23+
};
24+
25+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26+
27+
uint8_t get_first(const uint8_t packed) {
28+
return uint8_t((packed & 0xF0) >> 4);
29+
}
30+
31+
uint8_t get_second(const uint8_t packed) {
32+
return uint8_t(packed & 0x0F);
33+
}
34+
35+
uint8_t combine(const uint8_t first, const uint8_t second) {
36+
return uint8_t(first << 4 | second);
37+
}
38+
39+
/*
40+
* This shader packs the weight tensor into a texture.
41+
*
42+
* The original tensor has a (W, H) shape of (K / 2, N) and each scalar element
43+
* is a uint8_t, which contains 2 packed 4 bit uint values.
44+
*
45+
* The transform performed by this shader is to first transpose the tensor, so
46+
* the shape of the packed tensor becomes (N / 2, K). Then, the 4 bit integers
47+
* are re-packed in groups of 8. For each 4 uint8_t values, the "left" 4-bits
48+
* of each value contain the 0, 1, 2, 3 4-bit values, and the "right" 4-bits of
49+
* each value contain the 4, 5, 6, 7 4-bit values.
50+
*
51+
* As a concrete example, consider the following weight tensor. The | demarks
52+
* the packing boundary, so 1| 2 represents a single uint8_t value with 1 in the
53+
* leftmost 4 bits and 2 in the rightmost 4 bits.
54+
*
55+
* 1| 2, 3| 4, 5| 6, 7| 8,
56+
* 9|10, 11|12, 13|14, 15|16,
57+
* 17|18, 19|20, 21|22, 23|24,
58+
* 25|26, 27|28, 29|30, 31|32,
59+
* 33|34, 35|36, 37|38, 39|40,
60+
* 41|42, 43|44, 45|46, 47|48,
61+
* 49|50, 51|52, 53|54, 55|56,
62+
* 57|58, 59|60, 61|62, 63|64,
63+
*
64+
* After packing, the packed tensor would contain
65+
*
66+
* 1|33, 9|41, 17|49, 25|57,
67+
* 2|34, 10|42, 18|50, 26|58,
68+
* 3|35, 11|43, 19|51, 27|59,
69+
* 4|36, 12|44, 20|52, 28|60,
70+
* 5|37, 13|45, 21|53, 29|61,
71+
* 6|38, 14|46, 22|54, 30|62,
72+
* 7|39, 15|47, 23|55, 31|63,
73+
* 8|40, 16|48, 24|56, 32|64,
74+
*
75+
* The purpose of interleaving is to make it easier to extract the unpacked
76+
* values in order using the u8vec4 vectorized type. With the packing in place,
77+
* The 4-bit values can be extracted via
78+
*
79+
* u8vec4 packed;
80+
* u8vec4 vals_0123 = (packed & 0xF0) >> 4;
81+
* u8vec4 vals_4567 = (packed | 0x0F);
82+
*/
83+
void main() {
84+
// Each thread writes 2 output texels along the height axis
85+
ivec2 packed_pos = ivec2(
86+
gl_GlobalInvocationID.x,
87+
gl_GlobalInvocationID.y << 1);
88+
89+
// The packed tensor is width packed
90+
if ((packed_pos.x << 2) >= qmat2_sizes.x || packed_pos.y >= qmat2_sizes.y) {
91+
return;
92+
}
93+
94+
int out_col = packed_pos.x << 3;
95+
int out_row = packed_pos.y;
96+
97+
int in_col = out_row;
98+
int in_int8_col = in_col >> 1;
99+
int in_row = out_col;
100+
101+
int in_numrows = qmat2_sizes.x << 1;
102+
int in_numcols = qmat2_sizes.y;
103+
int in_num_int8_cols = qmat2_sizes.y >> 1;
104+
105+
uint8_t in_vals[8][2];
106+
for (int r = 0; r < 8; ++r) {
107+
if (in_row + r < in_numrows) {
108+
uint8_t in_val_packed = nchw_4x2[(in_row + r) * in_num_int8_cols + in_int8_col];
109+
in_vals[r][0] = get_first(in_val_packed);
110+
in_vals[r][1] = get_second(in_val_packed);
111+
} else {
112+
in_vals[r][0] = uint8_t(254);
113+
in_vals[r][1] = uint8_t(254);
114+
}
115+
}
116+
117+
u8vec4 out_tex_1 = u8vec4(
118+
combine(in_vals[0][0], in_vals[4][0]),
119+
combine(in_vals[1][0], in_vals[5][0]),
120+
combine(in_vals[2][0], in_vals[6][0]),
121+
combine(in_vals[3][0], in_vals[7][0]));
122+
123+
u8vec4 out_tex_2 = u8vec4(
124+
combine(in_vals[0][1], in_vals[4][1]),
125+
combine(in_vals[1][1], in_vals[5][1]),
126+
combine(in_vals[2][1], in_vals[6][1]),
127+
combine(in_vals[3][1], in_vals[7][1]));
128+
129+
$if STORAGE == "buffer":
130+
int stride = qmat2_sizes.x >> 2;
131+
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
132+
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
133+
$else:
134+
imageStore(t_qmat2, ivec3(packed_pos.xy, 0), out_tex_1);
135+
imageStore(t_qmat2, ivec3(packed_pos.x, packed_pos.y + 1, 0), out_tex_2);
136+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
pack_int4_linear_weight_transposed_interleaved:
8+
parameter_names_with_default_values:
9+
STORAGE: texture3d
10+
shader_variants:
11+
- NAME: pack_int4_linear_weight_transposed_interleaved_texture3d
12+
- NAME: pack_int4_linear_weight_transposed_interleaved_buffer
13+
STORAGE: buffer

0 commit comments

Comments
 (0)