Skip to content

Commit 8f0d797

Browse files
pytorchbot and Nathanael See authored
[ET-VK][int4] patch 4-bit linear op for ensuring w-packed in/out
Pull Request resolved: #8225 If the partitioner is using channels-packed setting for activations, then the checks will throw. Remove the checks and conditionally re-pack the input/output tensors if they are not width-packed. ghstack-source-id: 264952605 @exported-using-ghexport Differential Revision: [D68813946](https://our.internmc.facebook.com/intern/diff/D68813946/) --------- Co-authored-by: Nathanael See <[email protected]>
1 parent 8ec08f9 commit 8f0d797

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,6 @@ void check_q_4w_linear_args(
260260
const int group_size_val = graph.extract_scalar<int>(group_size);
261261
VK_CHECK_COND(K % group_size_val == 0);
262262

263-
VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim);
264-
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
265-
266263
VK_CHECK_COND(graph.has_standard_axis_map(mat1));
267264
VK_CHECK_COND(graph.has_standard_axis_map(out));
268265
}
@@ -320,13 +317,32 @@ void add_q_4w_linear_node(
320317

321318
const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);
322319

320+
ValueRef mat1_W_packed = mat1;
321+
ValueRef out_W_packed = out;
322+
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
323+
// Create temporary tensors to store the width packed versions of mat1 and out
324+
TmpTensor mat1_tmp(
325+
&graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked);
326+
TmpTensor out_tmp(
327+
&graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked);
328+
if (storage_type == utils::kTexture3D) {
329+
if (!graph.is_buffer_storage(out) &&
330+
graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
331+
// Ensure mat1 is width packed
332+
mat1_W_packed = mat1_tmp;
333+
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
334+
// Ensure out is packed correctly
335+
out_W_packed = out_tmp;
336+
}
337+
}
338+
323339
vkapi::ParamsBindList ubos({});
324-
ubos.append(graph.logical_limits_ubo(out));
325-
ubos.append(graph.sizes_ubo(mat1));
340+
ubos.append(graph.logical_limits_ubo(out_W_packed));
341+
ubos.append(graph.sizes_ubo(mat1_W_packed));
326342
ubos.append(graph.strides_ubo(mat2));
327343
ubos.append(graph.strides_ubo(scales_and_zeros));
328344

329-
utils::uvec3 global_wg_size = graph.logical_limits_of(out);
345+
utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed);
330346
utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);
331347

332348
graph.execute_nodes().emplace_back(new DispatchNode(
@@ -335,15 +351,20 @@ void add_q_4w_linear_node(
335351
global_wg_size,
336352
local_wg_size,
337353
// Inputs and Outputs
338-
{{out, vkapi::MemoryAccessType::WRITE},
339-
{{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}},
354+
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
355+
{{mat1_W_packed, mat2, scales_and_zeros},
356+
vkapi::MemoryAccessType::READ}},
340357
// Shader params buffers
341358
ubos,
342359
// Specialization Constants
343360
{SV(group_size_val)},
344361
// Resizing Logic
345362
resize_q_4w_linear_node,
346363
{}));
364+
if (!graph.is_buffer_storage(out) &&
365+
graph.packed_dim_of(out) != WHCN::kWidthDim) {
366+
viewFn(graph, {out_W_packed, graph.add_none(), out});
367+
}
347368
}
348369

349370
void linear_weight_int4(

0 commit comments

Comments (0)