diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl
deleted file mode 100644
index b8d7622f94d..00000000000
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#version 450 core
-
-#define PRECISION ${PRECISION}
-
-#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
-#define FLOAT_T ${buffer_scalar_type(DTYPE)}
-
-${define_active_storage_type(STORAGE)}
-
-${define_required_extensions(DTYPE)}
-$if STORAGE == "buffer":
-  ${define_required_extensions("int8")}
-
-
-$if BATCH_MODE:
-  #define BATCH_MODE
-
-#define TILE_ROWS ${TILE_ROWS}
-#define FOUR 4
-
-// we avoid mat4 and vec4 usage here as they compile to much less efficient
-// SPIR-V
-struct FloatMatrix_2d {
-  float data[TILE_ROWS][FOUR];
-};
-
-struct FloatMatrix_3d {
-  float data[TILE_ROWS][FOUR][FOUR];
-};
-
-#ifdef BATCH_MODE
-  #define FloatMatrix FloatMatrix_3d
-#else
-  #define FloatMatrix FloatMatrix_2d
-#endif
-
-#include "indexing_utils.h"
-
-layout(std430) buffer;
-
-${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
-${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
-${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
-${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)}
-
-$if STORAGE == "buffer":
-  ${layout_declare_ubo(4, "ivec4", "out_sizes")}
-  ${layout_declare_ubo(5, "ivec4", "out_strides")}
-  ${layout_declare_ubo(6, "int", "out_numel")}
-  ${layout_declare_ubo(7, "ivec4", "mat1_sizes")}
-  ${layout_declare_ubo(8, "ivec4", "mat1_strides")}
-  ${layout_declare_ubo(9, "ivec4", "qmat2_strides")}
-  ${layout_declare_ubo(10, "ivec4", "scales_strides")}
-$else:
-  ${layout_declare_ubo(4, "ivec3", "out_limits")}
-  ${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
-
-layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
-
-// This header file must be defined after the layout descriptors have been
-// declared because the functions in the header assume some variables have been
-// declared as layout descriptors.
-
-#ifdef USING_BUFFER
-
-#ifndef FLOAT_T
-#define FLOAT_T float
-#endif
-
-FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
-  const FLOAT_T scale = t_scales[out_idx.x];
-
-  FLOAT_T outval = FLOAT_T(0.0);
-
-  // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
-  int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
-  // Initial qmat2 tensor idx will be (0, out_idx.x, 0, 0); note that the qmat2
-  // tensor is transposed
-  int qmat2_offset = out_idx.x * qmat2_strides.y;
-
-  // TODO(ssjia): optimize memory access pattern by traversing K in inner loop
-  for (int i = 0; i < K; i++) {
-    const FLOAT_T mat1_val = t_mat1[mat1_offset];
-    const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
-
-    outval += mat1_val * mat2_val;
-
-    mat1_offset++;
-    qmat2_offset++;
-  }
-
-  return outval;
-}
-
-void main() {
-  const int out_bufi = int(gl_GlobalInvocationID.x);
-  if (out_bufi >= out_numel) {
-    return;
-  }
-
-  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);
-
-  t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x);
-}
-
-#else // USING_TEXTURE
-FloatMatrix q_8w_linear_optimized(const ivec3 out_idx_tl) {
-  FloatMatrix results;
-  for (int i = 0; i < TILE_ROWS; i++) {
-    for (int j = 0; j < FOUR; j++) {
-#ifdef BATCH_MODE
-      for (int k = 0; k < FOUR; k++) {
-        results.data[i][j][k] = 0.0f;
-      }
-#else
-      results.data[i][j] = 0.0f;
-#endif // BATCH_MODE
-    }
-  }
-
-  VEC4_T im_mat1_partial_load[TILE_ROWS];
-  VEC4_T im_mat2_partial_load[FOUR];
-
-#ifdef BATCH_MODE
-  for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
-    if (out_idx_tl.z + batch_idx >= out_limits.z) {
-      break;
-    }
-#endif
-  for (int k = 0; k < mat1_sizes.x; k++) {
-    for (int r = 0; r < TILE_ROWS; r++) {
-      ivec3 mat1_pos = ivec3(k, out_idx_tl.y * TILE_ROWS + r, 0);
-#ifdef BATCH_MODE
-      mat1_pos[2] = out_idx_tl.z + batch_idx;
-#endif
-
-      im_mat1_partial_load[r] = texelFetch(t_mat1, mat1_pos, 0);
-    }
-
-    for (int r = 0; r < FOUR; ++r) {
-      ivec3 qmat2_pos = ivec3(k, FOUR * out_idx_tl.x + r, 0);
-
-      im_mat2_partial_load[r] = texelFetch(t_qmat2, qmat2_pos, 0);
-    }
-
-    vec4 scales = texelFetch(t_scales, ivec3(out_idx_tl.x, 0, 0), 0);
-
-    // perform partial dot products and add partial result to results
-    for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
-      for (int out_col = 0; out_col < FOUR; out_col++) {
-#ifdef BATCH_MODE
-        results.data[out_row][out_col][batch_idx] +=
-#else
-        results.data[out_row][out_col] +=
-#endif
-            dot(im_mat1_partial_load[out_row],
-                im_mat2_partial_load[out_col] * scales[out_col]);
-      }
-    }
-  }
-#ifdef BATCH_MODE
-  }
-#endif
-  return results;
-}
-
-void main() {
-  const ivec3 out_idx = ivec3(gl_GlobalInvocationID);
-  if (any(greaterThanEqual(out_idx, out_limits))) {
-    return;
-  }
-
-  FloatMatrix results = q_8w_linear_optimized(out_idx);
-
-  ivec3 out_pos = ivec3(
-      out_idx.x,
-      out_idx.y * TILE_ROWS,
-#ifdef BATCH_MODE
-      out_idx.z * 4
-#else
-      out_idx.z
-#endif
-  );
-
-  for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++, out_pos[1]++) {
-    out_pos.x = out_idx.x;
-    $if BATCH_MODE:
-      for (int idx_r = 0; idx_r < FOUR; idx_r++, out_pos[0]++) {
-        write_texel(t_out, out_pos, VEC4_T(
-            results.data[idx_c][idx_r][0],
-            results.data[idx_c][idx_r][1],
-            results.data[idx_c][idx_r][2],
-            results.data[idx_c][idx_r][3]));
-      }
-    $else:
-      write_texel(t_out, out_pos, VEC4_T(
-          results.data[idx_c][0],
-          results.data[idx_c][1],
-          results.data[idx_c][2],
-          results.data[idx_c][3]));
-  }
-}
-
-#endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml
deleted file mode 100644
index 52bebf90125..00000000000
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-q_8w_linear_optimized:
-  parameter_names_with_default_values:
-    DTYPE: float
-    STORAGE: texture3d
-    MAT1_PACKING: W_packed
-    MAT2_PACKING: W_packed
-    BATCH_MODE: false
-    TILE_ROWS: 4
-  generate_variant_forall:
-    TILE_ROWS:
-      - VALUE: 4
-        SUFFIX: tile_row_4
-      - VALUE: 2
-        SUFFIX: tile_row_2
-    DTYPE:
-      - VALUE: float
-      - VALUE: half
-    STORAGE:
-      - VALUE: texture3d
-      - VALUE: buffer
-  shader_variants:
-    - NAME: q_8w_linear_optimized_W_packed_W_packed
-    - NAME: q_8w_linear_optimized_W_packed_H_packed
-      MAT2_PACKING: H_packed
-    - NAME: batch_q_8w_linear_optimized_W_packed_W_packed
-      BATCH_MODE: true
-    - NAME: batch_q_8w_linear_optimized_W_packed_H_packed
-      MAT2_PACKING: H_packed
-      BATCH_MODE: true
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
new file mode 100644
index 00000000000..c3bd9f41af9
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
@@ -0,0 +1,92 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}
+
+#define TILE_ROWS ${TILE_ROWS}
+
+${define_required_extensions(DTYPE)}
+
+$if STORAGE == "buffer":
+  ${define_required_extensions("int8")}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
+
+
+layout(push_constant) uniform restrict Block {
+  ivec4 out_sizes;
+  ivec4 in_sizes;
+  ivec4 weight_sizes;
+};
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS;
+  const uint out_col = gl_GlobalInvocationID.x << 2;
+
+  if (out_col >= out_sizes.x || out_row >= out_sizes.y) {
+    return;
+  }
+
+  VEC4_T a[TILE_ROWS];
+  VEC4_T b[4];
+  VEC4_T c[TILE_ROWS];
+
+  $if STORAGE == "buffer":
+    const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
+  $else:
+    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
+
+  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+    c[i] = VEC4_T(0.0);
+  }
+
+  for (int pos = 0; pos < in_sizes.x; pos += 4) {
+    // Preload weight tensor
+    [[unroll]] for (int i = 0; i < 4; i++) {
+      $if STORAGE == "buffer":
+        b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
+      $else:
+        b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
+    }
+
+    // Preload input tensor
+    [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
+      $if STORAGE == "buffer":
+        a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
+      $else:
+        a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
+    }
+
+    // Compute partial output
+    [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+      c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
+    }
+  }
+
+  // Store output tensor
+  [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
+    $if STORAGE == "buffer":
+      t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
+    $else:
+      imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
new file mode 100644
index 00000000000..b01af47e179
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+q_8w_linear_tiled:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: texture3d
+    TILE_ROWS: 4
+  shader_variants:
+    - NAME: q_8w_linear_tiled_o4x4_texture3d_float
+      STORAGE: texture3d
+      TILE_ROWS: 4
+    - NAME: q_8w_linear_tiled_o4x6_texture3d_float
+      STORAGE: texture3d
+      TILE_ROWS: 6
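Note on the new shader above: each invocation produces a TILE_ROWS x 4 tile of the output. gl_GlobalInvocationID.x selects a group of four output columns (one texel), gl_GlobalInvocationID.y selects a block of TILE_ROWS output rows, and the pos loop walks the reduction dimension four elements at a time, so each weight texel fetched at (out_col >> 2, pos + i) is a row of the hw-transposed weight. Below is a minimal scalar sketch of the same computation, for reference only; the function name, layouts, and types are illustrative assumptions, not part of this diff:

// Standalone scalar sketch of the product the tiled shader computes.
// Assumes: in is M x K row-major float, weight is the int8 weight in
// its original N x K layout (the shader reads it hw-transposed), scales
// holds one value per output channel n, and out is M x N.
#include <cstdint>
#include <vector>

std::vector<float> q_8w_linear_reference(
    const std::vector<float>& in, // M x K
    const std::vector<int8_t>& weight, // N x K
    const std::vector<float>& scales, // N
    int M, int K, int N) {
  std::vector<float> out(M * N, 0.0f);
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += in[m * K + k] * static_cast<float>(weight[n * K + k]);
      }
      // Like the shader's `c[i] * scales` at the store, the per-channel
      // scale is applied once per output element, not inside the k loop.
      out[m * N + n] = acc * scales[n];
    }
  }
  return out;
}

Applying the scale once at the store saves a multiply per reduction step compared with dequantizing each weight element inside the inner loop, which is what the deleted buffer-path kernel did.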
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index 2011331ec38..f4f5c853ddd 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -160,100 +160,111 @@ void add_q_8w_linear_node(
   }
 }
 
-void add_q_8w_linear_optimized_node(
+void add_q_8w_linear_tiled_node(
     ComputeGraph& graph,
     const ValueRef mat1,
     const ValueRef q_mat2_data,
     const ValueRef scales_data,
     const ValueRef out) {
-  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
-  ValueRef mat1_W_packed = mat1;
-  ValueRef out_W_packed = out;
-  if (!graph.is_buffer_storage(out) &&
-      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
-    // Ensure mat1 is width packed
-    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
-    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
-    // Ensure out is packed correctly
-    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
-  }
-
   utils::StorageType stype = graph.storage_type_of(out);
-  ValueRef q_mat2 =
-      prepack_standard(graph, q_mat2_data, stype, utils::kWidthPacked);
+  ValueRef q_mat2 = prepack_standard_hw_transposed(
+      graph, q_mat2_data, stype, utils::kWidthPacked);
   ValueRef scales =
       prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
 
-  std::string kernel_name = "q_8w_linear_optimized";
+  std::string kernel_name = "q_8w_linear_tiled";
   kernel_name.reserve(kShaderNameReserve);
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
-  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
-  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
-  const int mat1_dims = mat1_sizes.size();
-  if (mat1_dims == 3) {
-    kernel_name = "batch_" + kernel_name;
-  }
-  if (mat1_sizes.at(mat1_dims - 2) < 8) {
-    kernel_name += "_tile_row_2";
+  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
+  const int64_t M = utils::val_at(-2, mat1_sizes);
+  int out_tile_nrows = 4;
+  if (M % 6 == 0) {
+    kernel_name += "_o4x6";
+    out_tile_nrows = 6;
   } else {
-    kernel_name += "_tile_row_4";
+    kernel_name += "_o4x4";
+    out_tile_nrows = 4;
   }
-  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
-  vkapi::ParamsBindList ubos({});
+  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
+  global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
 
-  utils::uvec3 global_size;
-  utils::uvec3 local_size;
-  if (graph.is_buffer_storage(out)) {
-    ubos.append(
-        {graph.sizes_ubo(out_W_packed),
-         graph.strides_ubo(out_W_packed),
-         graph.numel_ubo(out_W_packed),
-         graph.sizes_ubo(mat1_W_packed),
-         graph.strides_ubo(mat1_W_packed),
-         graph.strides_ubo(q_mat2),
-         graph.strides_ubo(scales)});
-    global_size = graph.create_global_wg_size(out_W_packed);
-    local_size = graph.create_local_wg_size(out_W_packed);
-  } else {
-    global_size = graph.logical_limits_of(out_W_packed);
-    ubos.append(
-        {graph.logical_limits_ubo(out_W_packed),
-         graph.sizes_ubo(mat1_W_packed)});
-    if (mat1_sizes.at(mat1_dims - 2) < 8) {
-      global_size = global_size = utils::divup_vec(global_size, {1, 2, 1});
-    } else {
-      global_size = utils::divup_vec(global_size, {1, 4, 1});
-    }
-    local_size = {16, 3, 1};
-  }
+  utils::uvec3 local_wg_size{64, 1, 1};
 
   graph.execute_nodes().emplace_back(new DispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      local_size,
+      global_wg_size,
+      local_wg_size,
       // Inputs and Outputs
-      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
-       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
+      {{out, vkapi::kWrite}, {{mat1, q_mat2, scales}, vkapi::kRead}},
       // Shader params buffers
-      ubos,
+      {},
      // Specialization Constants
-      {}, // spec_vars,
+      {},
      // Resizing Logic
-      resize_q_8w_linear_node));
+      resize_q_8w_linear_node,
+      {},
+      // Push Constants
+      {{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}}));
+}
 
-  if (!graph.is_buffer_storage(out)) {
-    viewFn(graph, {out_W_packed, graph.add_none(), out});
+bool can_use_tiled_impl(
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef q_mat2_data,
+    const ValueRef scales_data,
+    const ValueRef out) {
+  (void)q_mat2_data;
+  (void)scales_data;
+
+  // Check if mat1 is not a 3D tensor or that batches = 1
+  // TODO(ssjia): Add support for batches in the tiled impl
+  if (graph.dim_of(mat1) == 3 && graph.size_at<int>(-1, mat1) != 1) {
+    return false;
+  }
+  // Check that K is a multiple of 4
+  if (graph.size_at<int>(-1, mat1) % 4 != 0) {
+    return false;
   }
+  // Check that M is a multiple of 4 or 6
+  if (graph.size_at<int>(-2, mat1) % 4 != 0 &&
+      graph.size_at<int>(-2, mat1) % 6 != 0) {
+    return false;
+  }
+  // Check that the storage type is texture
+  // TODO(ssjia): Add support for buffer storage in the tiled impl
+  if (graph.storage_type_of(out) != utils::kTexture3D) {
+    return false;
+  }
+  // Check that the packed dim is the width dim
+  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    return false;
+  }
+  // Check that no special axis mapping is used for the input
+  // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl
+  if (!graph.has_standard_axis_map(mat1)) {
+    return false;
+  }
+  // Check that no special axis mapping is used for the output
+  // TODO(ssjia): Add support for non-standard axis mapping in the tiled impl
+  if (!graph.has_standard_axis_map(out)) {
+    return false;
+  }
+
+  return true;
 }
 
 void weight_int8pack_mm(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
   check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]);
+  if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) {
+    return add_q_8w_linear_tiled_node(
+        graph, args[0], args[1], args[2], args[3]);
+  }
   return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]);
 }
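A note on dispatch: can_use_tiled_impl gates the tiled path on texture storage, width packing, standard axis maps, K divisible by 4, and M divisible by 4 or 6; everything else falls back to the generic add_q_8w_linear_node. Variant selection in add_q_8w_linear_tiled_node then depends only on M. A hypothetical host-side sketch of that selection (the helper name is illustrative, not part of the diff):

#include <cstdint>
#include <string>

// M divisible by 6 prefers the 6-row tile; otherwise the 4-row tile is
// used. can_use_tiled_impl has already guaranteed M % 4 == 0 or M % 6 == 0.
std::string select_tiled_variant(int64_t M) {
  return (M % 6 == 0) ? "q_8w_linear_tiled_o4x6" : "q_8w_linear_tiled_o4x4";
}

For example, M = 48 picks the o4x6 variant (the % 6 check wins even though 48 is also divisible by 4), M = 64 picks o4x4, and M = 7 fails both divisibility checks, so weight_int8pack_mm takes the fallback path.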