From 03947a85ba9e660b1ba57630d15e3cf63f783a0b Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 6 Feb 2025 11:29:07 -0800 Subject: [PATCH] [ET-VK] support biases in buffer-based linear shader ## Context As title. Add support for biases in the buffer-based addmm and linear implementation. Differential Revision: [D69247282](https://our.internmc.facebook.com/intern/diff/D69247282/) ghstack-source-id: 265095770 Pull Request resolved: https://github.com/pytorch/executorch/pull/8284 --- ...ve_buffer.glsl => addmm_naive_buffer.glsl} | 23 ++++-- ...ve_buffer.yaml => addmm_naive_buffer.yaml} | 5 +- .../vulkan/runtime/graph/ops/impl/Linear.cpp | 74 +++++++++++++++++-- backends/vulkan/test/op_tests/cases.py | 27 +------ backends/vulkan/test/test_vulkan_delegate.py | 23 ++++++ 5 files changed, 116 insertions(+), 36 deletions(-) rename backends/vulkan/runtime/graph/ops/glsl/{matmul_naive_buffer.glsl => addmm_naive_buffer.glsl} (73%) rename backends/vulkan/runtime/graph/ops/glsl/{matmul_naive_buffer.yaml => addmm_naive_buffer.yaml} (81%) diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl similarity index 73% rename from backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl rename to backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl index dd91685c8a7..1f3061ea100 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl @@ -10,6 +10,9 @@ #define PRECISION ${PRECISION} +$if HAS_BIAS: + #define HAS_BIAS + #define T ${buffer_scalar_type(DTYPE)} ${define_required_extensions(DTYPE)} @@ -19,6 +22,8 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")} ${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")} ${layout_declare_ubo(B, "ivec4", "out_sizes")} ${layout_declare_ubo(B, "ivec4", "out_strides")} ${layout_declare_ubo(B, "ivec4", "mat1_sizes")} @@ -26,6 +31,8 @@ ${layout_declare_ubo(B, "ivec4", "mat1_strides")} ${layout_declare_ubo(B, "ivec4", "mat2_sizes")} ${layout_declare_ubo(B, "ivec4", "mat2_strides")} ${layout_declare_ubo(B, "int", "out_numel")} +$if HAS_BIAS: + ${layout_declare_ubo(B, "float", "alpha", "float", "beta")} #include "indexing_utils.h" @@ -34,25 +41,25 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; ${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")} void main() { - const ivec4 out_bufix = ivec4( + const ivec4 out_tidx = ivec4( gl_GlobalInvocationID.x, gl_GlobalInvocationID.y, gl_GlobalInvocationID.z % out_sizes.z, gl_GlobalInvocationID.z / out_sizes.z); - if (any(greaterThanEqual(out_bufix, out_sizes))) { + if (any(greaterThanEqual(out_tidx, out_sizes))) { return; } int mat1_bufi = tidx_to_bufi( - ivec4(0, out_bufix.y, out_bufix.z, out_bufix.w), mat1_strides); + ivec4(0, out_tidx.y, out_tidx.z, out_tidx.w), mat1_strides); int mat2_bufi; if (mat2_is_transposed > 0) { mat2_bufi = tidx_to_bufi( - ivec4(0, out_bufix.x, 0, 0), mat2_strides); + ivec4(0, out_tidx.x, 0, 0), mat2_strides); } else { mat2_bufi = tidx_to_bufi( - ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides); + ivec4(out_tidx.x, 0, out_tidx.z, out_tidx.w), mat2_strides); } int mat2_stride; @@ -70,6 +77,10 @@ void main() { mat2_bufi += mat2_stride; } - const int out_bufi = tidx_to_bufi(out_bufix, out_strides); + const int out_bufi = tidx_to_bufi(out_tidx, out_strides); +#ifdef HAS_BIAS + t_out[out_bufi] = T(alpha) * T(sum) + T(beta) * t_bias[out_tidx.x]; +#else t_out[out_bufi] = T(sum); +#endif // HAS_BIAS } diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.yaml similarity index 81% rename from backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml rename to backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.yaml index 54eb444f73d..b093d0c80b2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.yaml @@ -4,13 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -matmul_naive_buffer: +addmm_naive_buffer: parameter_names_with_default_values: DTYPE: float STORAGE: buffer + HAS_BIAS: false generate_variant_forall: DTYPE: - VALUE: float - VALUE: half shader_variants: - NAME: matmul_naive_buffer + - NAME: addmm_naive_buffer + HAS_BIAS: true diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 1cba6de851c..ddcdb41ece8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -84,7 +84,7 @@ struct Params final { float beta; }; -void add_addmm_naive_node( +void add_addmm_naive_texture_node( ComputeGraph& graph, const ValueRef self_data, const ValueRef mat1, @@ -134,6 +134,69 @@ void add_addmm_naive_node( {mat2_is_transposed})); } +void add_addmm_naive_buffer_node( + ComputeGraph& graph, + const ValueRef self_data, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef beta, + const ValueRef alpha, + const ValueRef out, + const Params& params, + const ValueRef mat2_is_transposed) { + (void)beta; + (void)alpha; + ValueRef mat2 = prepack_standard( + graph, + mat2_data, + graph.storage_type_of(out), + utils::kHeightPacked, + /*passthrough = */ true); + ValueRef self = prepack_standard( + graph, + self_data, + graph.storage_type_of(out), + utils::kWidthPacked, + /*passthrough = */ true); + + std::string kernel_name = "addmm_naive_buffer"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + utils::uvec3 global_size = { + graph.size_at(-1, out), + graph.size_at(-2, out), + graph.size_at(-3, out) * graph.size_at(-4, out)}; + + int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef && + graph.get_bool(mat2_is_transposed)) + ? 1 + : 0; + + graph.execute_nodes().emplace_back(new DispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + global_size, + graph.create_local_wg_size(global_size), + // Inputs and Outputs + {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}}, + // Shader params buffers + { + graph.sizes_ubo(out), + graph.strides_ubo(out), + graph.sizes_ubo(mat1), + graph.strides_ubo(mat1), + graph.sizes_ubo(mat2), + graph.strides_ubo(mat2), + graph.numel_ubo(out), + graph.create_params_buffer(params), + }, + // Specialization Constants + {mat2_is_transposed_val}, + // Resizing Logic + resize_addmm_node, + {mat2_is_transposed})); +} + void add_addmm_optimized_node( ComputeGraph& graph, const ValueRef self_data, @@ -246,11 +309,14 @@ void add_addmm_node( } Params params = {alpha_val, beta_val}; - if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) { + if (graph.is_buffer_storage(out)) { + add_addmm_naive_buffer_node( + graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); + } else if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) { add_addmm_optimized_node( graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) { - add_addmm_naive_node( + add_addmm_naive_texture_node( graph, self, mat1, mat2, beta, alpha, out, params, mat2_is_transposed); } else { VK_THROW("Input should be channel packed or width packed."); @@ -283,8 +349,6 @@ void linear(ComputeGraph& graph, const std::vector& args) { if (graph.val_is_none(bias)) { return add_matmul_node(graph, input, weight, out, mat2_is_transposed); } else { - // Buffer implementation does not yet support biases - VK_CHECK_COND(!graph.is_buffer_storage(out)); return add_addmm_node( graph, bias, diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 2130573c0cc..38d87240b80 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -126,7 +126,8 @@ def get_addmm_inputs(): ] -def get_linear_texture_inputs(): +@register_test_suite("aten.linear.default") +def get_linear_inputs(): MKN_list = common_MKN_list inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list] @@ -141,32 +142,10 @@ def get_linear_texture_inputs(): "utils::kWidthPacked", "utils::kChannelsPacked", ] - test_suite.test_name_suffix = "texture" - return test_suite - - -def get_linear_buffer_inputs(): - MKN_list = common_MKN_list - - inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list] - inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list] - - test_suite = VkTestSuite(inputs_list) - test_suite.dtypes = ["at::kFloat"] - test_suite.layouts = [ - "utils::kWidthPacked", - "utils::kChannelsPacked", - ] - test_suite.storage_types = ["utils::kBuffer"] - test_suite.test_name_suffix = "buffer" + test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"] return test_suite -@register_test_suite("aten.linear.default") -def get_linear_test_suites(): - return [get_linear_texture_inputs(), get_linear_buffer_inputs()] - - @register_test_suite("aten._weight_int8pack_mm.default") def get_weight_int8pack_mm_inputs(): MKN_list = [ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 129f40df8b1..5fba5ed54cf 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -1711,3 +1711,26 @@ def forward(self, x): (torch.ones(size=[5, 4, 1, 2, 6]),), expect_no_delegates=True, ) + + def test_vulkan_backend_large_linear_layer(self): + class LinearModel(torch.nn.Module): + def __init__( + self, n_pca_basis: int, n_sh_basis: int, n_gaussians: int + ) -> None: + super(LinearModel, self).__init__() + self.fc1 = torch.nn.Linear( + n_pca_basis, (n_sh_basis + 3 + 3 + 4) * n_gaussians + ) + + def forward(self, x: torch.Tensor): + out = self.fc1(x) + return out + + n_pca_basis = 64 + n_sh_basis = 6 + n_gaussians = 2**16 + + self.lower_module_and_test_output( + LinearModel(n_pca_basis, n_sh_basis, n_gaussians), + (torch.ones(n_pca_basis),), + )