From 8f3b16ffe9a7d00958eed2e90037fc4e2aeb511a Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Mon, 31 Mar 2025 09:09:15 -0700
Subject: [PATCH] [ET-VK] Efficient tiled int8 matmul

## Context

Introduce an optimized tiled implementation for computing the weight int8-quantized linear operation. This implementation takes advantage of the following principles to squeeze out performance:

* Compute an output tile with each thread, rather than a single output element. This allows for better re-use of loaded input tensor data.
* Compute the output tile by iteratively loading tiles of the input matrices, caching them in registers, and then performing the `fma` accumulations to obtain a partial output. By splitting the data loading and the computation into distinct steps, the GPU can hide latency more effectively, i.e. switch to a warp that is ready to compute while the current warp is waiting on a data load. A CPU-side sketch of this scheme is included below.
* Use a work group size of `{N, 1, 1}`. This makes it so that all the threads in a work group load the same row of the input matrix and consecutive columns of the weight matrix. This way, the row of the input is kept hot in the cache, and accesses to the weight matrix can be coalesced, thanks to the previous diff un-transposing the weight matrix.
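To make the loop structure concrete, here is a minimal CPU-side sketch of the per-thread tile computation. It mirrors the new `q_8w_linear_tiled.glsl` shader (preload a 4x4 weight tile and a `TILE_ROWS` x 4 input tile into registers, accumulate, then apply the per-channel scales once on store), but it is an editorial illustration: the scalar row-major layout and all names in it are assumptions, not code from this patch.

```cpp
#include <array>
#include <cstdint>
#include <vector>

constexpr int TILE_ROWS = 4;

// in: M x K floats (row-major); weight: K x N int8 (row-major, i.e. already
// un-transposed); scales: N per-output-channel dequant scales; out: M x N.
// (out_row, out_col) is the top-left corner of the tile this "thread" owns.
// K is assumed to be a multiple of 4, as the new code path requires.
void q8w_linear_tile(
    const std::vector<float>& in,
    const std::vector<int8_t>& weight,
    const std::vector<float>& scales,
    std::vector<float>& out,
    int K, int N, int out_row, int out_col) {
  std::array<std::array<float, 4>, TILE_ROWS> c{}; // zero-initialized tile

  for (int pos = 0; pos < K; pos += 4) {
    // Load phase 1: a 4x4 tile of the weight matrix (4 K-rows of the same
    // 4 consecutive output channels) into "registers".
    float b[4][4];
    for (int i = 0; i < 4; ++i) {
      for (int j = 0; j < 4; ++j) {
        b[i][j] = static_cast<float>(weight[(pos + i) * N + out_col + j]);
      }
    }
    // Load phase 2: a TILE_ROWS x 4 tile of the input matrix.
    float a[TILE_ROWS][4];
    for (int r = 0; r < TILE_ROWS; ++r) {
      for (int i = 0; i < 4; ++i) {
        a[r][i] = in[(out_row + r) * K + pos + i];
      }
    }
    // Compute phase: each input scalar scales one whole weight row; this is
    // the scalar expansion of the shader's
    // c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3].
    for (int r = 0; r < TILE_ROWS; ++r) {
      for (int i = 0; i < 4; ++i) {
        for (int j = 0; j < 4; ++j) {
          c[r][j] += a[r][i] * b[i][j];
        }
      }
    }
  }
  // Dequantize with the per-output-channel scales once, on store.
  for (int r = 0; r < TILE_ROWS; ++r) {
    for (int j = 0; j < 4; ++j) {
      out[(out_row + r) * N + out_col + j] = c[r][j] * scales[out_col + j];
    }
  }
}
```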
Differential Revision: [D72066587](https://our.internmc.facebook.com/intern/diff/D72066587/)

[ghstack-poisoned]
---
 .../graph/ops/glsl/q_8w_linear_optimized.glsl | 212 ------------------
 .../graph/ops/glsl/q_8w_linear_optimized.yaml |  35 ---
 .../graph/ops/glsl/q_8w_linear_tiled.glsl     |  92 ++++++++
 .../graph/ops/glsl/q_8w_linear_tiled.yaml     |  18 ++
 .../graph/ops/impl/QuantizedLinear.cpp        | 137 +++++------
 5 files changed, 184 insertions(+), 310 deletions(-)
 delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl
 delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml

diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl
deleted file mode 100644
index b8d7622f94d..00000000000
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#version 450 core
-
-#define PRECISION ${PRECISION}
-
-#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
-#define FLOAT_T ${buffer_scalar_type(DTYPE)}
-
-${define_active_storage_type(STORAGE)}
-
-${define_required_extensions(DTYPE)}
-$if STORAGE == "buffer":
-  ${define_required_extensions("int8")}
-
-
-$if BATCH_MODE:
-  #define BATCH_MODE
-
-#define TILE_ROWS ${TILE_ROWS}
-#define FOUR 4
-
-// we avoid mat4 and vec4 usage here as they compile to much less efficient
-// SPIR-V
-struct FloatMatrix_2d {
-  float data[TILE_ROWS][FOUR];
-};
-
-struct FloatMatrix_3d {
-  float data[TILE_ROWS][FOUR][FOUR];
-};
-
-#ifdef BATCH_MODE
-  #define FloatMatrix FloatMatrix_3d
-#else
-  #define FloatMatrix FloatMatrix_2d
-#endif
-
-#include "indexing_utils.h"
-
-layout(std430) buffer;
-
-${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
-${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
-${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
-${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)}
-
-$if STORAGE == "buffer":
-  ${layout_declare_ubo(4, "ivec4", "out_sizes")}
-  ${layout_declare_ubo(5, "ivec4", "out_strides")}
-  ${layout_declare_ubo(6, "int", "out_numel")}
-  ${layout_declare_ubo(7, "ivec4", "mat1_sizes")}
-  ${layout_declare_ubo(8, "ivec4", "mat1_strides")}
-  ${layout_declare_ubo(9, "ivec4", "qmat2_strides")}
-  ${layout_declare_ubo(10, "ivec4", "scales_strides")}
-$else:
-  ${layout_declare_ubo(4, "ivec3", "out_limits")}
-  ${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
-
-layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
-
-// This header file must be defined after the layout descriptors have been
-// declared because the functions in the header assume some variables have been
-// declared as layout descriptors.
-
-#ifdef USING_BUFFER
-
-#ifndef FLOAT_T
-#define FLOAT_T float
-#endif
-
-FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
-  const FLOAT_T scale = t_scales[out_idx.x];
-
-  FLOAT_T outval = FLOAT_T(0.0);
-
-  // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
-  int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
-  // Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2
-  // tensor is transposed
-  int qmat2_offset = out_idx.x * qmat2_strides.y;
-
-  // TODO(ssjia): optimize memory access pattern by traversing K in inner loop
-  for (int i = 0; i < K; i++) {
-    const FLOAT_T mat1_val = t_mat1[mat1_offset];
-    const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
-
-    outval += mat1_val * mat2_val;
-
-    mat1_offset++;
-    qmat2_offset++;
-  }
-
-  return outval;
-}
-
-void main() {
-  const int out_bufi = int(gl_GlobalInvocationID.x);
-  if (out_bufi >= out_numel) {
-    return;
-  }
-
-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);
-
-  t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x);
-}
-
-#else // USING_TEXTURE
-FloatMatrix q_8w_linear_optimized(const ivec3 out_idx_tl) {
-  FloatMatrix results;
-  for (int i = 0; i < TILE_ROWS; i++) {
-    for (int j = 0; j < FOUR; j++) {
-#ifdef BATCH_MODE
-      for (int k = 0; k < FOUR; k++) {
-        results.data[i][j][k] = 0.0f;
-      }
-#else
-      results.data[i][j] = 0.0f;
-#endif // BATCH_MODE
-    }
-  }
-
-  VEC4_T im_mat1_partial_load[TILE_ROWS];
-  VEC4_T im_mat2_partial_load[FOUR];
-
-#ifdef BATCH_MODE
-  for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
-    if (out_idx_tl.z + batch_idx >= out_limits.z) {
-      break;
-    }
-#endif
-  for (int k = 0; k < mat1_sizes.x; k++) {
-    for (int r = 0; r < TILE_ROWS; r++) {
-      ivec3 mat1_pos = ivec3(k, out_idx_tl.y * TILE_ROWS + r, 0);
-#ifdef BATCH_MODE
-      mat1_pos[2] = out_idx_tl.z + batch_idx;
-#endif
-
-      im_mat1_partial_load[r] = texelFetch(t_mat1, mat1_pos, 0);
-    }
-
-    for (int r = 0; r < FOUR; ++r) {
-      ivec3 qmat2_pos = ivec3(k, FOUR * out_idx_tl.x + r, 0);
-
-      im_mat2_partial_load[r] = texelFetch(t_qmat2, qmat2_pos, 0);
-    }
-
-    vec4 scales = texelFetch(t_scales, ivec3(out_idx_tl.x, 0, 0), 0);
-
-    // perform partial dot products and add partial result to results
-    for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
-      for (int out_col = 0; out_col < FOUR; out_col++) {
-#ifdef BATCH_MODE
-        results.data[out_row][out_col][batch_idx] +=
-#else
-        results.data[out_row][out_col] +=
-#endif
-            dot(im_mat1_partial_load[out_row],
-                im_mat2_partial_load[out_col] * scales[out_col]);
-      }
-    }
-  }
-#ifdef BATCH_MODE
-  }
-#endif
-  return results;
-}
-
-void main() {
-  const ivec3 out_idx = ivec3(gl_GlobalInvocationID);
-  if (any(greaterThanEqual(out_idx, out_limits))) {
-    return;
-  }
-
-  FloatMatrix results = q_8w_linear_optimized(out_idx);
-
-  ivec3 out_pos = ivec3(
-      out_idx.x,
-      out_idx.y * TILE_ROWS,
-#ifdef BATCH_MODE
-      out_idx.z * 4
-#else
-      out_idx.z
-#endif
-);
-
-  for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++, out_pos[1]++) {
-    out_pos.x = out_idx.x;
-    $if BATCH_MODE:
-      for (int idx_r = 0; idx_r < FOUR; idx_r++, out_pos[0]++) {
-        write_texel(t_out, out_pos, VEC4_T(
-            results.data[idx_c][idx_r][0],
-            results.data[idx_c][idx_r][1],
-            results.data[idx_c][idx_r][2],
-            results.data[idx_c][idx_r][3]));
-      }
-    $else:
-      write_texel(t_out, out_pos, VEC4_T(
-          results.data[idx_c][0],
-          results.data[idx_c][1],
-          results.data[idx_c][2],
-          results.data[idx_c][3]));
-  }
-}
-
-#endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml
deleted file mode 100644
index 52bebf90125..00000000000
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-q_8w_linear_optimized:
-  parameter_names_with_default_values:
-    DTYPE: float
-    STORAGE: texture3d
-    MAT1_PACKING: W_packed
-    MAT2_PACKING: W_packed
-    BATCH_MODE: false
-    TILE_ROWS: 4
-  generate_variant_forall:
-    TILE_ROWS:
-      - VALUE: 4
-        SUFFIX: tile_row_4
-      - VALUE: 2
-        SUFFIX: tile_row_2
-    DTYPE:
-      - VALUE: float
-      - VALUE: half
-    STORAGE:
-      - VALUE: texture3d
-      - VALUE: buffer
-  shader_variants:
-    - NAME: q_8w_linear_optimized_W_packed_W_packed
-    - NAME: q_8w_linear_optimized_W_packed_H_packed
-      MAT2_PACKING: H_packed
-    - NAME: batch_q_8w_linear_optimized_W_packed_W_packed
-      BATCH_MODE: true
-    - NAME: batch_q_8w_linear_optimized_W_packed_H_packed
-      MAT2_PACKING: H_packed
-      BATCH_MODE: true
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
new file mode 100644
index 00000000000..c3bd9f41af9
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
+
+#define TILE_ROWS ${TILE_ROWS}
+
+${define_required_extensions(DTYPE)}
+
+$if STORAGE == "buffer":
+  ${define_required_extensions("int8")}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_sizes;
+  ivec4 in_sizes;
+  ivec4 weight_sizes;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
+  const uint out_col = gl_GlobalInvocationID.x << 2;
+
+  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
+    return;
+  }
+
+  VEC4_T a[TILE_ROWS];
+  VEC4_T b[4];
+  VEC4_T c[TILE_ROWS];
+
+  $if STORAGE == "buffer":
+    const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
+  $else:
+    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
+
+  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+    c[i] = VEC4_T(0.0);
+  }
+
+  for (int pos = 0; pos < in_sizes.x; pos += 4) {
+    // Preload weight tensor
+    [[unroll]] for (int i = 0; i < 4; i++) {
+      $if STORAGE == "buffer":
+        b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
+      $else:
+        b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
+    }
+
+    // Preload input tensor
+    [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
+      $if STORAGE == "buffer":
+        a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
+      $else:
+        a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
+    }
+
+    // Compute partial output
+    [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+      c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
+    }
+  }
+
+  // Store output tensor
+  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+    $if STORAGE == "buffer":
+      t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
+    $else:
+      imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
new file mode 100644
index 00000000000..b01af47e179
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+q_8w_linear_tiled:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: texture3d
+    TILE_ROWS: 4
+  shader_variants:
+    - NAME: q_8w_linear_tiled_o4x4_texture3d_float
+      STORAGE: texture3d
+      TILE_ROWS: 4
+    - NAME: q_8w_linear_tiled_o4x6_texture3d_float
+      STORAGE: texture3d
+      TILE_ROWS: 6
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index 2011331ec38..f4f5c853ddd 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -160,100 +160,111 @@ void add_q_8w_linear_node(
   }
 }
 
-void add_q_8w_linear_optimized_node(
+void add_q_8w_linear_tiled_node(
     ComputeGraph& graph,
     const ValueRef mat1,
    const ValueRef q_mat2_data,
     const ValueRef scales_data,
     const ValueRef out) {
-  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
-  ValueRef mat1_W_packed = mat1;
-  ValueRef out_W_packed = out;
-  if (!graph.is_buffer_storage(out) &&
-      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
-    // Ensure mat1 is width packed
-    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
-    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
-    // Ensure out is packed correctly
-    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
-  }
   utils::StorageType stype = graph.storage_type_of(out);
-  ValueRef q_mat2 =
-      prepack_standard(graph, q_mat2_data, stype, utils::kWidthPacked);
+  ValueRef q_mat2 = prepack_standard_hw_transposed(
+      graph, q_mat2_data, stype, utils::kWidthPacked);
   ValueRef scales =
      prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
 
-  std::string kernel_name = "q_8w_linear_optimized";
+  std::string kernel_name = "q_8w_linear_tiled";
   kernel_name.reserve(kShaderNameReserve);
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
-  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
-  const int mat1_dims = mat1_sizes.size();
-  if (mat1_dims == 3) {
-    kernel_name = "batch_" + kernel_name;
-  }
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    kernel_name += "_tile_row_2";
+  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
+  const int64_t M = utils::val_at(-2, mat1_sizes);
+  int out_tile_nrows = 4;
+  if (M % 6 == 0) {
+    kernel_name += "_o4x6";
+    out_tile_nrows = 6;
   } else {
-    kernel_name += "_tile_row_4";
+    kernel_name += "_o4x4";
+    out_tile_nrows = 4;
   }
-  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  vkapi::ParamsBindList ubos({});
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
 
-  utils::uvec3 global_size;
-  utils::uvec3 local_size;
-  if (graph.is_buffer_storage(out)) {
-    ubos.append(
-        {graph.sizes_ubo(out_W_packed),
-         graph.strides_ubo(out_W_packed),
-         graph.numel_ubo(out_W_packed),
-         graph.sizes_ubo(mat1_W_packed),
-         graph.strides_ubo(mat1_W_packed),
-         graph.strides_ubo(q_mat2),
-         graph.strides_ubo(scales)});
-    global_size = graph.create_global_wg_size(out_W_packed);
-    local_size = graph.create_local_wg_size(out_W_packed);
-  } else {
-    global_size = graph.logical_limits_of(out_W_packed);
-    ubos.append(
-        {graph.logical_limits_ubo(out_W_packed),
-         graph.sizes_ubo(mat1_W_packed)});
-    if (mat1_sizes.at(mat1_dims - 2) < 8) {
-      global_size = global_size = utils::divup_vec(global_size, {1, 2, 1});
-    } else {
-      global_size = utils::divup_vec(global_size, {1, 4, 1});
-    }
-    local_size = {16, 3, 1};
-  }
+  utils::uvec3 local_wg_size{64, 1, 1};
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      local_size,
+      global_wg_size,
+      local_wg_size,
       // Inputs and Outputs
-      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
-       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
+      {{out, vkapi::kWrite}, {{mat1, q_mat2, scales}, vkapi::kRead}},
       // Shader params buffers
-      ubos,
+      {},
       // Specialization Constants
-      {}, // spec_vars,
+      {},
      // Resizing Logic
-      resize_q_8w_linear_node));
+      resize_q_8w_linear_node,
+      {},
+      // Push Constants
+      {{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}}));
+}
 
-  if (!graph.is_buffer_storage(out)) {
-    viewFn(graph, {out_W_packed, graph.add_none(), out});
-  }
+bool can_use_tiled_impl(
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef q_mat2_data,
+    const ValueRef scales_data,
+    const ValueRef out) {
+  (void)q_mat2_data;
+  (void)scales_data;
+
+  // Check if mat1 is not a 3D tensor or that batches = 1
+  // TODO(ssjia): Add support for batches in the tiled impl
+  if (graph.dim_of(mat1) == 3 && graph.size_at<int>(-3, mat1) != 1) {
+    return false;
+  }
+  // Check that K is a multiple of 4
+  if (graph.size_at<int>(-1, mat1) % 4 != 0) {
+    return false;
+  }
+  // Check that M is a multiple of 4 or 6
+  if (graph.size_at<int>(-2, mat1) % 4 != 0 &&
+      graph.size_at<int>(-2, mat1) % 6 != 0) {
+    return false;
+  }
+  // Check that the storage type is texture
+  // TODO(ssjia): Add support for buffer storage in the tiled impl
+  if (graph.storage_type_of(out) != utils::kTexture3D) {
+    return false;
+  }
+  // Check that the packed dim is the width dim
+  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    return false;
+  }
+  // Check that no special axis mapping is used for the input
+  // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl
+  if (!graph.has_standard_axis_map(mat1)) {
+    return false;
+  }
+  // Check that no special axis mapping is used for the output
+  // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl
+  if (!graph.has_standard_axis_map(out)) {
+    return false;
+  }
+
+  return true;
 }
 
 void weight_int8pack_mm(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
   check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
+  if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) {
+    return add_q_8w_linear_tiled_node(
+        graph, args[0], args[1], args[2], args[3]);
+  }
   return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
 }
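As a final reviewer's note (not part of the patch): the dispatch sizing in `add_q_8w_linear_tiled_node` can be sanity-checked in isolation. The sketch below mirrors its sizing math under the assumption that the output image is N/4 texels wide and M rows tall; the helper name and the example shapes are illustrative only.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Mirrors the sizing logic above: x counts 4-wide texel columns (N / 4),
// y counts row-tiles (M / out_tile_nrows), and z is 1 since batches are not
// yet supported. The local size {64, 1, 1} then groups 64 threads that share
// one input row while reading consecutive weight columns.
std::array<uint32_t, 3> tiled_global_wg_size(uint32_t M, uint32_t N) {
  const uint32_t out_tile_nrows = (M % 6 == 0) ? 6 : 4;
  return {N / 4, M / out_tile_nrows, 1};
}

int main() {
  // Example: M = 256, N = 2048 -> tile height 4, giving a 512 x 64 x 1
  // dispatch where every thread writes a 4x4 block of the output matrix.
  const std::array<uint32_t, 3> wg = tiled_global_wg_size(256, 2048);
  std::cout << wg[0] << " x " << wg[1] << " x " << wg[2] << "\n";
  return 0;
}
```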