From b24a5941e9faf2528dff2b9d04965a50932a7356 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Tue, 12 Mar 2024 08:59:16 -0700
Subject: [PATCH] [ET-VK] Introduce graph runtime shader library that enables
 dynamic shapes

## Context

https://github.com/pytorch/pytorch/pull/121598 introduces the ability to support dynamic shapes through tensor metadata updates. The idea is fairly simple. Instead of a shader accepting a single UBO with the size data for all of its arguments:

```
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
  ivec4 output_sizes;
  ivec4 other_sizes;
  float alpha;
}
```

shaders accept a separate UBO for each piece of tensor metadata:

```
layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes {
  ivec4 data;
}
out_sizes;

layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
  ivec4 data;
}
in_sizes;

layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
  ivec4 data;
}
other_sizes;

layout(set = 0, binding = 6) uniform PRECISION restrict Alpha {
  float data;
}
alpha;
```

Each UBO is owned and maintained by the corresponding `vTensor` instance. To resize a graph input, each tensor in the graph only needs to update its metadata UBOs via a `tensor.virtual_resize(new_sizes)` call. Shader dispatches in subsequent command buffer submissions will then see the updated metadata and execute as if the tensor had the new sizes.

This changeset introduces a new shader library for the Vulkan graph runtime that enables dynamic shapes through this technique, rather than relying on the shader library from PyTorch Vulkan.

## Considerations

Technically, the UBO update technique could be applied to the shaders from PyTorch Vulkan as well. If so, why introduce a new shader library for the graph runtime?

The primary motivation is code quality. First, having each `vTensor` supply UBOs for its own metadata greatly reduces the need for operator-specific ad-hoc `Params` structs that organize arguments to write into an `api::UniformParamsBuffer`. Constructing an `ExecuteNode` for binary operators is now

```
graph.execute_nodes().emplace_back(new ExecuteNode(
    graph,
    api::shader_registry().get_shader_info(kernel_name.str()),
    global_size,
    local_size,
    {{out, api::MemoryAccessType::WRITE},
     {{arg1, arg2}, api::MemoryAccessType::READ}},
    {t_out.gpu_sizes_ubo(),
     t_in1.gpu_sizes_ubo(),
     t_in2.gpu_sizes_ubo(),
     graph.create_params_buffer(alpha_val)}));
```

instead of

```
ArithmeticParams block{
    get_size_as_ivec4(t_out),
    get_size_as_ivec4(t_in1),
    get_size_as_ivec4(t_in2),
    alpha_val,
};
api::UniformParamsBuffer params(graph.context(), block);

graph.execute_nodes().emplace_back(new ExecuteNode(
    graph,
    shader,
    global_size,
    local_size,
    {{out, api::MemoryAccessType::WRITE},
     {{arg1, arg2}, api::MemoryAccessType::READ}},
    std::move(params)));
```

Another consideration is that https://github.com/pytorch/pytorch/pull/115948, which landed fairly recently, enables much more expressive shader templates through the use of Python code blocks in the GLSL template. This makes it easy to write shader templates that express variants for different data types, packing structures, etc. Introducing a new shader library provides the opportunity to rewrite the shaders from PyTorch Vulkan in a more generic and extensible way.
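To make the dynamic shape flow concrete, the intended usage pattern looks roughly like the sketch below. It is assembled from the APIs exercised by the tests in this diff (`prepare`, `encode_execute`, `virtual_resize`, `execute`); the `a` and `out` value refs and the `fill_vtensor` helper are borrowed from the test fixtures and stand in for a real model's I/O:

```
// Encode the graph once; later shape changes do not require re-encoding.
graph.prepare();
graph.encode_execute();

// When an input is resized, update the sizes UBO owned by each affected
// vTensor. No descriptor sets or command buffers are rebuilt.
std::vector<int64_t> new_sizes = {8, 12, 64};
graph.get_val(a.value).toTensor().virtual_resize(new_sizes);
graph.get_val(out.value).toTensor().virtual_resize(new_sizes);

// The next submission reads the updated metadata UBOs and executes as if
// the tensors had been created with new_sizes.
fill_vtensor(graph, a, 1.0f);
graph.execute();
```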
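Relatedly, shaders are now resolved by name from a global registry rather than through the compile-time `VK_KERNEL` macro, so selecting a generated variant reduces to string assembly. A minimal sketch using the `apply_dtype_suffix` helper added in this diff:

```
// Build the base variant name generated from all_shaders.yaml, e.g.
// "binary_add", then append "_float" or "_half" based on the output
// tensor's image format.
std::stringstream kernel_name;
kernel_name << "binary_" << op_name;
apply_dtype_suffix(kernel_name, t_out);

// Resolves to the generated variant, e.g. "binary_add_float".
api::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name.str());
```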
Differential Revision: [D54754545](https://our.internmc.facebook.com/intern/diff/D54754545/) [ghstack-poisoned] --- .../runtime/graph/ops/glsl/all_shaders.yaml | 62 +++++++ .../runtime/graph/ops/glsl/binary_op.glsl | 67 +++++++ .../graph/ops/glsl/broadcasting_utils.h | 17 ++ .../runtime/graph/ops/glsl/image_to_nchw.glsl | 61 +++++++ .../runtime/graph/ops/glsl/indexing_utils.h | 17 ++ .../runtime/graph/ops/glsl/nchw_to_image.glsl | 61 +++++++ .../ops/impl/{Arithmetic.cpp => BinaryOp.cpp} | 72 ++++---- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 44 +---- .../vulkan/runtime/graph/ops/impl/Staging.h | 10 +- .../graph/ops/utils/ShaderNameUtils.cpp | 51 ++++++ .../Arithmetic.h => utils/ShaderNameUtils.h} | 23 +-- .../runtime/graph/ops/utils/StagingUtils.cpp | 122 ++++--------- backends/vulkan/targets.bzl | 63 ++++++- .../vulkan/test/vulkan_compute_api_test.cpp | 168 +++++++++--------- 14 files changed, 569 insertions(+), 269 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h create mode 100644 backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h create mode 100644 backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl rename backends/vulkan/runtime/graph/ops/impl/{Arithmetic.cpp => BinaryOp.cpp} (77%) create mode 100644 backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp rename backends/vulkan/runtime/graph/ops/{impl/Arithmetic.h => utils/ShaderNameUtils.h} (50%) diff --git a/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml b/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml new file mode 100644 index 00000000000..f954528ee7e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +binary_op: + parameter_names_with_default_values: + OPERATOR: X + A * Y + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: binary_add + - NAME: binary_sub + OPERATOR: X - Y + - NAME: binary_mul + OPERATOR: X * Y + - NAME: binary_div + OPERATOR: X / Y + - NAME: binary_pow + OPERATOR: pow(X, Y) + - NAME: binary_floor_divide + OPERATOR: floor(X / Y) + +image_to_nchw: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: image3d_to_nchw_C_packed + - NAME: image2d_to_nchw_C_packed + NDIM: 2 + +nchw_to_image: + parameter_names_with_default_values: + NDIM: 3 + DTYPE: float + PACKING: CHANNELS_PACKED + generate_variant_forall: + DTYPE: + - VALUE: "half" + SUFFIX: "half" + - VALUE: "float" + SUFFIX: "float" + shader_variants: + - NAME: nchw_to_image3d_C_packed + - NAME: nchw_to_image2d_C_packed + NDIM: 2 diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl new file mode 100644 index 00000000000..4e67eac3a55 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#include "indexing_utils.h" +#include "broadcasting_utils.h" + +#define PRECISION ${PRECISION} + +#define OP(X, Y, A) ${OPERATOR} + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; +layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other; + +layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes { + ivec4 data; +} +out_sizes; + +layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { + ivec4 data; +} +in_sizes; + +layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes { + ivec4 data; +} +other_sizes; + +layout(set = 0, binding = 6) uniform PRECISION restrict Alpha { + float data; +} +alpha; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data); + + if (any(greaterThanEqual(coord, out_sizes.data))) { + return; + } + + ivec4 in_coord = out_coord_to_in_coord(coord, in_sizes.data); + vec4 in_texel = texelFetch( + image_in, + COORD_TO_POS_${PACKING}(in_coord, in_sizes.data), + 0); + + ivec4 other_coord = out_coord_to_in_coord(coord, other_sizes.data); + vec4 other_texel = texelFetch( + image_other, + COORD_TO_POS_${PACKING}(other_coord, other_sizes.data), + 0); + + imageStore(image_out, pos, OP(in_texel, other_texel, alpha.data)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h b/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h new file mode 100644 index 00000000000..dc8635b8813 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +ivec4 out_coord_to_in_coord(const ivec4 out_coord, const ivec4 in_sizes) { + ivec4 in_coord = out_coord; + for (int i = 0; i < 4; ++i) { + if (in_sizes[i] == 1) { + in_coord[i] = 0; + } + } + return in_coord; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl new file mode 100644 index 00000000000..f966f7584b2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in; +layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_out; + +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { + ivec4 data; +} +cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + const ${VEC4_T[DTYPE]} intex = texelFetch(image_in, ${GET_POS[NDIM]("pos")}, 0); + + const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y); + + if (coord.z < cpu_sizes.data.z) { + buffer_out.data[buf_indices.x] = intex.x; + } + if (coord.z + 1 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.y] = intex.y; + } + if (coord.z + 2 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.z] = intex.z; + } + if (coord.z + 3 < cpu_sizes.data.z) { + buffer_out.data[buf_indices.w] = intex.w; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h new file mode 100644 index 00000000000..7bac6b5116e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#define POS_TO_COORD_CHANNELS_PACKED(pos, sizes) \ + ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z) + +#define COORD_TO_POS_CHANNELS_PACKED(coord, sizes) \ + ivec3(coord.x, coord.y, (coord.z + coord.w * sizes.z) / 4) + +#define COORD_TO_BUFFER_IDX(coord, sizes) \ + coord.x + coord.y* sizes.x + coord.z* sizes.y* sizes.x + \ + coord.w* sizes.z* sizes.y* sizes.x; diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl new file mode 100644 index 00000000000..00ed3fe5e48 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -0,0 +1,61 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#include "indexing_utils.h" + +layout(std430) buffer; + +layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; +layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { + ${T[DTYPE]} data[]; +} +buffer_in; + +layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { + ivec4 data; +} +gpu_sizes; + +layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { + ivec4 data; +} +cpu_sizes; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + + if (any(greaterThanEqual(coord, gpu_sizes.data))) { + return; + } + + const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); + const ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y); + + ${T[DTYPE]} val_x = buffer_in.data[buf_indices.x]; + ${T[DTYPE]} val_y = buffer_in.data[buf_indices.y]; + ${T[DTYPE]} val_z = buffer_in.data[buf_indices.z]; + ${T[DTYPE]} val_w = buffer_in.data[buf_indices.w]; + + ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w); + + if (coord.z + 3 >= cpu_sizes.data.z) { + ivec4 c_ind = ivec4(coord.z) + ivec4(0, 1, 2, 3); + vec4 valid_c = vec4(lessThan(c_ind, ivec4(cpu_sizes.data.z))); + texel = texel * valid_c; + } + + imageStore(image_out, ${GET_POS[NDIM]("pos")}, texel); +} diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp similarity index 77% rename from backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp rename to backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 453e290045c..887b529b208 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -6,8 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -#include - #include #include @@ -15,32 +13,15 @@ #include #include +#include + namespace at { namespace native { namespace vulkan { -#define DEFINE_ARITHMETIC_WITH_ALPHA_FN(function, shader) \ - void function(ComputeGraph& graph, const std::vector& args) { \ - return add_arithmetic_node( \ - graph, args[0], args[1], args[2], args[3], VK_KERNEL(shader)); \ - } - -#define DEFINE_ARITHMETIC_FN(function, shader) \ - void function(ComputeGraph& graph, const std::vector& args) { \ - return add_arithmetic_node( \ - graph, args[0], args[1], kDummyValueRef, args[2], VK_KERNEL(shader)); \ - } - -DEFINE_ARITHMETIC_WITH_ALPHA_FN(add, add); -DEFINE_ARITHMETIC_WITH_ALPHA_FN(sub, sub); - -// Floor div does not have an alpha, but a string argument (which is unused) is -// passed in at the same location as the alpha argument in other op. 
-DEFINE_ARITHMETIC_WITH_ALPHA_FN(floor_div, floor_divide); - -DEFINE_ARITHMETIC_FN(mul, mul); -DEFINE_ARITHMETIC_FN(div, div); -DEFINE_ARITHMETIC_FN(pow, pow); +std::string get_arithmetic_shader_name(const std::string& op_name) { + return "arithmetic_" + op_name; +} void add_arithmetic_node( ComputeGraph& graph, @@ -48,7 +29,7 @@ void add_arithmetic_node( const ValueRef in2, const ValueRef alpha, const ValueRef out, - const api::ShaderInfo& shader) { + const std::string& op_name) { ValueRef arg1 = prepack_if_tensor_ref(graph, in1); ValueRef arg2 = prepack_if_tensor_ref(graph, in2); @@ -56,7 +37,7 @@ void add_arithmetic_node( vTensor& t_in2 = graph.get_val(arg2).toTensor(); vTensor& t_out = graph.get_val(out).toTensor(); - api::utils::uvec3 global_size = t_out.extents(); + api::utils::uvec3 global_size = t_out.virtual_extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); float alpha_val = 1.0f; @@ -66,23 +47,46 @@ void add_arithmetic_node( alpha_val = extract_scalar(graph.get_val(alpha)); } - ArithmeticParams block{ - get_size_as_ivec4(t_out), - get_size_as_ivec4(t_in1), - get_size_as_ivec4(t_in2), - alpha_val, - }; + std::stringstream kernel_name; + kernel_name << "binary_" << op_name; + apply_dtype_suffix(kernel_name, t_out); graph.execute_nodes().emplace_back(new ExecuteNode( graph, - shader, + VK_KERNEL_FROM_STR(kernel_name.str()), global_size, local_size, {{out, api::MemoryAccessType::WRITE}, {{arg1, arg2}, api::MemoryAccessType::READ}}, - {graph.create_params_buffer(block)})); + {t_out.gpu_sizes_ubo(), + t_in1.gpu_sizes_ubo(), + t_in2.gpu_sizes_ubo(), + graph.create_params_buffer(alpha_val)})); } +#define DEFINE_ARITHMETIC_WITH_ALPHA_FN(function, shader) \ + void function(ComputeGraph& graph, const std::vector& args) { \ + return add_arithmetic_node( \ + graph, args[0], args[1], args[2], args[3], #shader); \ + } + +#define DEFINE_ARITHMETIC_FN(function, shader) \ + void function(ComputeGraph& graph, const std::vector& args) { \ + return add_arithmetic_node( \ + graph, args[0], args[1], kDummyValueRef, args[2], #shader); \ + } + +DEFINE_ARITHMETIC_WITH_ALPHA_FN(add, add); +DEFINE_ARITHMETIC_WITH_ALPHA_FN(sub, sub); + +// Floor div does not have an alpha, but a string argument (which is unused) is +// passed in at the same location as the alpha argument in other op. 
+DEFINE_ARITHMETIC_WITH_ALPHA_FN(floor_div, floor_divide); + +DEFINE_ARITHMETIC_FN(mul, mul); +DEFINE_ARITHMETIC_FN(div, div); +DEFINE_ARITHMETIC_FN(pow, pow); + REGISTER_OPERATORS { VK_REGISTER_OP(aten.add.Tensor, add); VK_REGISTER_OP(aten.sub.Tensor, sub); diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 1659a030ff4..b3319e6dac8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -17,22 +17,6 @@ namespace at { namespace native { namespace vulkan { -StagingParams create_staging_params(const vTensor& t) { - int32_t height = api::utils::safe_downcast(dim_at(t)); - int32_t width = api::utils::safe_downcast(dim_at(t)); - int32_t channels = - api::utils::safe_downcast(dim_at(t)); - - int32_t plane_size = height * width; - int32_t c_depth = api::utils::div_up(channels, 4); - - return { - api::utils::make_ivec3(t.extents()), - plane_size, - {c_depth, channels}, - }; -} - void add_staging_to_tensor_node( ComputeGraph& graph, const ValueRef in_staging, @@ -52,7 +36,7 @@ void add_staging_to_tensor_node( local_size, {{out_tensor, api::MemoryAccessType::WRITE}, {in_staging, api::MemoryAccessType::READ}}, - {graph.create_params_buffer(create_staging_params(t_out))})); + {t_out.gpu_sizes_ubo(), t_out.cpu_sizes_ubo()})); } void add_tensor_to_staging_node( @@ -67,26 +51,6 @@ void add_tensor_to_staging_node( api::utils::uvec3 global_size = t_in.extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - StagingParams sp = create_staging_params(t_in); - - // TODO(T181194784): These are workgroup sizes for special cases. Refactor the - // calculation of workgroup sizes to a standalone function. We should use - // scalar type to get the shader name, and use the shader name to get the - // workgroup size. 
- if (t_in.dtype() == api::ScalarType::QUInt8 || - t_in.dtype() == api::ScalarType::QInt8 || t_in.dtype() == api::kBool) { - if (sp.plane_size % 4 == 0) { - global_size.data[0u] = sp.plane_size / 4; - global_size.data[1u] = 1; - local_size.data[0u] *= local_size.data[1u]; - local_size.data[1u] = 1; - } else { - uint32_t numel = t_in.numel(); - global_size = {api::utils::div_up(numel, uint32_t(4)), 1u, 1u}; - local_size = {64u, 1u, 1u}; - } - } - graph.execute_nodes().emplace_back(new ExecuteNode( graph, shader, @@ -94,7 +58,7 @@ void add_tensor_to_staging_node( local_size, {{in_tensor, api::MemoryAccessType::READ}, {out_staging, api::MemoryAccessType::WRITE}}, - {graph.create_params_buffer(sp)})); + {t_in.gpu_sizes_ubo(), t_in.cpu_sizes_ubo()})); } ValueRef prepack(ComputeGraph& graph, const ValueRef vref) { @@ -107,8 +71,6 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) { api::utils::uvec3 global_size = t.extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); - StagingParams sp = create_staging_params(t); - graph.prepack_nodes().emplace_back(new PrepackNode( graph, shader, @@ -116,7 +78,7 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) { local_size, vref, v, - {graph.create_params_buffer(sp)})); + {t.gpu_sizes_ubo(), t.cpu_sizes_ubo()})); return v; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h index 99bdf667c6b..425d77489fe 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h @@ -22,22 +22,14 @@ void add_staging_to_tensor_node( ComputeGraph& graph, const ValueRef in_staging, const ValueRef out_tensor); + void add_tensor_to_staging_node( ComputeGraph& graph, const ValueRef in_tensor, const ValueRef out_staging); -struct StagingParams final { - api::utils::ivec3 extents; - int32_t plane_size; - api::utils::ivec2 channel_info; -}; - ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v); -// Expose for the Vulkan Compute API tests. -StagingParams create_staging_params(const vTensor& t); - } // namespace vulkan } // namespace native } // namespace at diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp new file mode 100644 index 00000000000..e941f32e162 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +namespace at { +namespace native { +namespace vulkan { + +void apply_dtype_suffix(std::stringstream& kernel_name, const vTensor& tensor) { + switch (tensor.image().format()) { + case VK_FORMAT_R32G32B32A32_SFLOAT: + kernel_name << "_float"; + break; + case VK_FORMAT_R16G16B16A16_SFLOAT: + kernel_name << "_half"; + break; + case VK_FORMAT_R32G32B32A32_SINT: + kernel_name << "_int"; + break; + default: + break; + } +} + +void apply_memory_layout_suffix( + std::stringstream& kernel_name, + const vTensor& tensor) { + switch (tensor.gpu_memory_layout()) { + case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED: + kernel_name << "_C_packed"; + break; + case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED: + kernel_name << "_H_packed"; + break; + case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED: + kernel_name << "_W_packed"; + break; + default: + break; + } +} + +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h similarity index 50% rename from backends/vulkan/runtime/graph/ops/impl/Arithmetic.h rename to backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h index b81ee21e648..b4c6c3a6bcc 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h +++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h @@ -10,26 +10,19 @@ #ifdef USE_VULKAN_API -#include +#include + +#include namespace at { namespace native { namespace vulkan { -void add_arithmetic_node( - ComputeGraph& graph, - const ValueRef in1, - const ValueRef in2, - const ValueRef alpha, - const ValueRef out, - const api::ShaderInfo& shader); - -struct ArithmeticParams final { - api::utils::ivec4 outputSizes; - api::utils::ivec4 input1Sizes; - api::utils::ivec4 input2Sizes; - float alpha; -}; +void apply_dtype_suffix(std::stringstream& kernel_name, const vTensor& tensor); + +void apply_memory_layout_suffix( + std::stringstream& kernel_name, + const vTensor& tensor); } // namespace vulkan } // namespace native diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index 50f812df841..45307c8a9d9 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -8,6 +8,7 @@ // @lint-ignore-every CLANGTIDY facebook-security-vulnerable-memcpy +#include #include #include @@ -92,101 +93,50 @@ void copy_staging_to_ptr( api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) { if (v_dst.is_quantized()) { - switch (v_dst.storage_type()) { - case api::StorageType::TEXTURE_3D: - switch (v_dst.dtype()) { - case api::ScalarType::QUInt8: - return VK_KERNEL(nchw_to_image_uint8); - case api::ScalarType::QInt8: - return VK_KERNEL(nchw_to_image_int8); - case api::ScalarType::QInt32: - return VK_KERNEL(nchw_to_image_int32); - default: - VK_THROW( - "Vulkan quantization currently not supported for dtype ", - v_dst.dtype()); - } - case api::StorageType::TEXTURE_2D: - switch (v_dst.dtype()) { - case api::ScalarType::QUInt8: - return VK_KERNEL(nchw_to_image2d_uint8); - case api::ScalarType::QInt8: - return VK_KERNEL(nchw_to_image2d_int8); - case api::ScalarType::QInt32: - return VK_KERNEL(nchw_to_image2d_int32); - default: - VK_THROW( - "Vulkan quantization currently not supported for dtype ", - v_dst.dtype()); - } - default: - VK_THROW("No kernel available!"); - case api::StorageType::BUFFER: - case api::StorageType::UNKNOWN: - VK_THROW("Requested 
storage type must be a texture type."); - } + VK_THROW("Quantized Tensors are currently not supported!"); } - if (v_dst.dtype() == api::kFloat) { - switch (v_dst.storage_type()) { - case api::StorageType::TEXTURE_3D: - return VK_KERNEL(nchw_to_image); - case api::StorageType::TEXTURE_2D: - return VK_KERNEL(nchw_to_image2d); - default: - VK_THROW("No kernel available!"); - } - } else if (v_dst.dtype() == api::kBool) { - switch (v_dst.storage_type()) { - case api::StorageType::TEXTURE_3D: - return VK_KERNEL(nchw_to_image_bool); - default: - VK_THROW("No kernel available!"); - } - } else { - VK_THROW("Unsupported dtype!"); + std::stringstream kernel_name; + + switch (v_dst.storage_type()) { + case api::StorageType::TEXTURE_3D: + kernel_name << "nchw_to_image3d"; + break; + case api::StorageType::TEXTURE_2D: + kernel_name << "nchw_to_image2d"; + break; + default: + VK_THROW("No kernel available!"); } + + apply_memory_layout_suffix(kernel_name, v_dst); + apply_dtype_suffix(kernel_name, v_dst); + + return VK_KERNEL_FROM_STR(kernel_name.str()); } api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) { - if (v_src.is_quantized() || v_src.dtype() == api::kBool) { - auto plane_size = - dim_at(v_src) * dim_at(v_src); - switch (v_src.storage_type()) { - case api::StorageType::TEXTURE_3D: - switch (v_src.dtype()) { - case api::ScalarType::QUInt8: - case api::ScalarType::QInt8: - case api::kBool: - return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4) - : VK_KERNEL(image_to_nchw_uint); - case api::ScalarType::QInt32: - return VK_KERNEL(image_to_nchw_int32); - default: - VK_THROW( - "Vulkan quantization currently not supported for dtype ", - v_src.dtype()); - } - default: - VK_THROW("No kernel available!"); - case api::StorageType::BUFFER: - case api::StorageType::UNKNOWN: - VK_THROW("Requested storage type must be a texture type."); - } + if (v_src.is_quantized()) { + VK_THROW("Quantized Tensors are currently not supported!"); } - if (v_src.dtype() == api::kFloat) { - switch (v_src.storage_type()) { - case api::StorageType::TEXTURE_3D: - return VK_KERNEL(image_to_nchw); - case api::StorageType::TEXTURE_2D: - return VK_KERNEL(image2d_to_nchw); - default: - VK_THROW("No kernel available!"); - } - } else { - VK_THROW("Unsupported dtype!"); + std::stringstream kernel_name; + + switch (v_src.storage_type()) { + case api::StorageType::TEXTURE_3D: + kernel_name << "image3d_to_nchw"; + break; + case api::StorageType::TEXTURE_2D: + kernel_name << "image2d_to_nchw"; + break; + default: + VK_THROW("No kernel available!"); } + + apply_memory_layout_suffix(kernel_name, v_src); + apply_dtype_suffix(kernel_name, v_src); + + return VK_KERNEL_FROM_STR(kernel_name.str()); } } // namespace vulkan diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 76a3bd61ee9..d4f58062b17 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -1,5 +1,54 @@ +load("@fbsource//tools/build_defs:fbsource_utils.bzl", "is_fbcode") +load("@fbsource//tools/build_defs:glob_defs.bzl", "subdir_glob") load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +def get_glsl_image_format(): + if native.read_config("pt", "vulkan_full_precision", "0") == "0": + return "rgba16f" + return "rgba32f" + +def vulkan_spv_shader_lib(name, spv_filegroup): + gen_aten_vulkan_spv_target = "//caffe2/tools:gen_aten_vulkan_spv_bin" + glslc_path = "//caffe2/fb/vulkan/dotslash:glslc" + if is_fbcode(): + gen_aten_vulkan_spv_target = "//caffe2:gen_vulkan_spv_bin" + glslc_path = 
"//caffe2/fb/vulkan/tools:glslc" + + genrule_cmd = [ + "$(exe {})".format(gen_aten_vulkan_spv_target), + "--glsl-paths $(location {})".format(spv_filegroup), + "--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()), + "--glslc-path=$(exe {})".format(glslc_path), + "--tmp-dir-path=$OUT", + ] + + genrule_name = "gen_{}_cpp".format(name) + runtime.genrule( + name = genrule_name, + outs = { + "{}.cpp".format(name): ["spv.cpp"], + }, + cmd = " ".join(genrule_cmd), + default_outs = ["."], + labels = ["uses_dotslash"], + ) + + runtime.cxx_library( + name = name, + srcs = [ + ":{}[{}.cpp]".format(genrule_name, name), + ], + # Static initialization is used to register shaders to the global shader registry, + # therefore link_whole must be True to make sure unused symbols are not discarded. + # @lint-ignore BUCKLINT: Avoid `link_whole=True` + link_whole = True, + # Define a soname that can be used for dynamic loading in Java, Python, etc. + soname = "lib{}.$(ext)".format(name), + exported_deps = [ + "//caffe2:torch_vulkan_api", + ], + ) + def define_common_targets(): runtime.genrule( name = "gen_vk_delegate_schema", @@ -38,6 +87,18 @@ def define_common_targets(): ], ) + runtime.filegroup( + name = "vulkan_graph_runtime_shaders", + srcs = subdir_glob([ + ("runtime/graph/ops/glsl", "*"), + ]), + ) + + vulkan_spv_shader_lib( + name = "vulkan_graph_runtime_shaderlib", + spv_filegroup = ":vulkan_graph_runtime_shaders", + ) + runtime.cxx_library( name = "vulkan_graph_runtime", srcs = native.glob([ @@ -53,7 +114,7 @@ def define_common_targets(): "@EXECUTORCH_CLIENTS", ], exported_deps = [ - "//caffe2:torch_vulkan_spv", + ":vulkan_graph_runtime_shaderlib", ], define_static_target = False, # Static initialization is used to register operators to the global operator registry, diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 9df6a8dd2d1..5d4725e5dd5 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -12,7 +12,6 @@ #include -#include #include #include @@ -101,59 +100,6 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) { check_staging_buffer(staging_buffer, 4.0f); } -TEST_F(VulkanComputeAPITest, buffer_copy_sanity_check) { - // Simple test that copies data into a and reads from a - std::vector sizes = {4, 4, 1}; - vTensor a = CREATE_FLOAT_BUFFER(sizes, /*allocate_memory = */ true); - - // Input data - std::vector data_in(a.gpu_numel()); - std::fill(data_in.begin(), data_in.end(), 2.524f); - - // Fill input tensor - fill_vtensor(a, data_in); - - // Read back data - std::vector data_out(a.gpu_numel()); - extract_vtensor(a, data_out); - - // Check output - for (const auto& d : data_out) { - EXPECT_TRUE(d == 2.524f); - } -} - -TEST_F(VulkanComputeAPITest, buffer_deferred_allocation_test) { - // Same as buffer_copy_sanity_check, but defers memory allocation - - std::vector sizes = {4, 4, 1}; - vTensor a = CREATE_FLOAT_BUFFER(sizes, /*allocate_memory = */ false); - - EXPECT_TRUE(get_vma_allocation_count() == 0); - - // Input data - std::vector data_in(a.gpu_numel()); - std::fill(data_in.begin(), data_in.end(), 1.234f); - - // Allocate memory at the last possible opportunity - api::MemoryAllocation a_mem = allocate_memory_for(a); - a.buffer().bind_allocation(a_mem); - - EXPECT_TRUE(get_vma_allocation_count() == 1); - - // Fill input tensor - fill_vtensor(a, data_in); - - // Read back data - std::vector data_out(a.gpu_numel()); - 
extract_vtensor(a, data_out); - - // Check output - for (const auto& d : data_out) { - EXPECT_TRUE(d == 1.234f); - } -} - TEST_F(VulkanComputeAPITest, texture_add_sanity_check) { // Simple test that performs a + b -> c @@ -502,8 +448,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) { GraphConfig config; ComputeGraph graph(config); - std::vector size_big = {4, 4, 4}; - std::vector size_small = {4, 4, 1}; + std::vector size_big = {8, 64, 124}; + std::vector size_small = {8, 1, 124}; // Build graph @@ -552,8 +498,8 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) { GraphConfig config; ComputeGraph graph(config); - std::vector size_big = {4, 4, 4}; - std::vector size_small = {4, 4, 1}; + std::vector size_big = {8, 73, 62}; + std::vector size_small = {8, 73, 1}; CREATE_WEIGHT_TENSOR(w1, size_small, 3.5f); CREATE_WEIGHT_TENSOR(w2, size_small, 3.0f); @@ -605,8 +551,8 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { GraphConfig config; ComputeGraph graph(config); - std::vector size_big = {4, 4, 4}; - std::vector size_small = {4, 4, 1}; + std::vector size_big = {48, 54, 4}; + std::vector size_small = {48, 1, 4}; // Build graph @@ -619,11 +565,6 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { api::kFloat, /*shared_object_idx = */ 4); - // Allocation count will be 4: - // 1 uniform buffer for each staging shader args - // 1 staging buffer for each input tensor - EXPECT_TRUE(get_vma_allocation_count() == 4); - ValueRef c = graph.add_tensor( size_big, api::kFloat, @@ -637,12 +578,6 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { api::kFloat, /*shared_object_idx = */ 2); - // Allocation count will be 7, three are new: - // 1 uniform buffer for arithmetic shader args - // 1 uniform buffer for staging shader args - // 1 staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 7); - ValueRef e = graph.add_tensor( size_big, api::kFloat, @@ -655,21 +590,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { out.value = e; out.staging = graph.set_output_tensor(out.value); - // Allocation count will be 10, three are new: - // 1 uniform buffer for arithmetic shader - // 1 uniform buffer for staging shader - // 1 staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 10); - graph.prepare(); graph.encode_execute(); - // Allocation count will be 13, three shared objects are allocated for total: - // 4 staging buffers for each I/O tensor - // 6 uniform buffers to store params for each shader dispatch - // 3 shared objects to back tensor memory - EXPECT_TRUE(get_vma_allocation_count() == 13); - // Run graph for (float i = 4.0f; i < 30.0f; i += 7.0f) { @@ -693,3 +616,82 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) { } } } + +TEST(VulkanComputeGraphTest, test_manual_virtual_resize) { + GraphConfig config; + ComputeGraph graph(config); + + std::vector size_big = {12, 64, 64}; + std::vector size_small = {12, 64, 64}; + + // Build graph + + IOValueRef a = graph.add_input_tensor( + size_big, + api::kFloat, + /*shared_object_idx = */ 2); + IOValueRef b = graph.add_input_tensor( + size_small, + api::kFloat, + /*shared_object_idx = */ 4); + + ValueRef c = graph.add_tensor( + size_big, + api::kFloat, + /*shared_object_idx = */ 6); + + auto addFn = VK_GET_OP_FN("aten.add.Tensor"); + addFn(graph, {a.value, b.value, kDummyValueRef, c}); + + IOValueRef d = graph.add_input_tensor( + size_small, + api::kFloat, + /*shared_object_idx = */ 2); + + ValueRef e = graph.add_tensor( + size_big, + api::kFloat, + 
/*shared_object_idx = */ 4); + + auto mulFn = VK_GET_OP_FN("aten.mul.Tensor"); + mulFn(graph, {c, d.value, e}); + + IOValueRef out = {}; + out.value = e; + out.staging = graph.set_output_tensor(out.value); + + graph.prepare(); + graph.encode_execute(); + + // Run graph + + std::vector> new_sizes_list = { + {8, 44, 34}, {4, 13, 56}, {8, 12, 64}, {12, 55, 33}, {4, 54, 10}}; + + for (auto& new_sizes : new_sizes_list) { + graph.get_val(a.value).toTensor().virtual_resize(new_sizes); + graph.get_val(b.value).toTensor().virtual_resize(new_sizes); + graph.get_val(c).toTensor().virtual_resize(new_sizes); + graph.get_val(d.value).toTensor().virtual_resize(new_sizes); + graph.get_val(e).toTensor().virtual_resize(new_sizes); + + float val_a = new_sizes[1] + 4.0f; + float val_b = new_sizes[2] + 1.5f; + float val_d = new_sizes[0] + 2.0f; + float val_out = (val_a + val_b) * val_d; + + fill_vtensor(graph, a, val_a); + fill_vtensor(graph, b, val_b); + fill_vtensor(graph, d, val_d); + + // Execute graph + graph.execute(); + + EXTRACT_TENSOR(out); + + // Sanity check that the values are correct + for (const auto& val : data_out) { + EXPECT_TRUE(val == val_out); + } + } +}
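A note on the layout convention the new shaders standardize on: `indexing_utils.h` defines the channels-packed mapping between a 3D texture position and a 4D tensor coordinate. The host-side C++ rendering below is purely illustrative (the GLSL macros in the diff are the source of truth); it assumes, as in PyTorch Vulkan, that `gpu_sizes` pads the channel dimension up to a multiple of 4 while `cpu_sizes` holds the logical sizes, with sizes ordered (W, H, C, N) as x/y/z/w:

```
struct IVec3 { int x, y, z; };
struct IVec4 { int x, y, z, w; };

// POS_TO_COORD_CHANNELS_PACKED: each texel along pos.z holds 4 consecutive
// channels, so channel = (pos.z * 4) % C_gpu and batch = (pos.z * 4) / C_gpu.
IVec4 pos_to_coord_c_packed(IVec3 pos, IVec4 gpu_sizes) {
  return {pos.x, pos.y, (pos.z * 4) % gpu_sizes.z, (pos.z * 4) / gpu_sizes.z};
}

// COORD_TO_BUFFER_IDX: index into the NCHW staging buffer, width fastest.
int coord_to_buffer_idx(IVec4 coord, IVec4 cpu_sizes) {
  return coord.x + coord.y * cpu_sizes.x +
      coord.z * cpu_sizes.y * cpu_sizes.x +
      coord.w * cpu_sizes.z * cpu_sizes.y * cpu_sizes.x;
}

// Example: a tensor with logical C = 6 is stored with gpu C = 8. The texel
// at pos (1, 2, 1) maps to coord (1, 2, 4, 0) and covers channels 4..7;
// channels 6 and 7 are padding, which is why image_to_nchw guards each
// store with `coord.z + k < cpu_sizes.data.z`.
```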