diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 5109e198206..7fde7e04f91 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -179,6 +179,11 @@ utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout( return utils::kChannelsPacked; } +bool ComputeGraph::device_name_contains(const char* substr) { + return context_->adapter_ptr()->device_name().find(substr) != + std::string::npos; +} + void ComputeGraph::check_no_active_value_ptrs() { VK_CHECK_COND( values_in_use_ == 0, diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 3d46aa327b8..d09597ad778 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -443,6 +443,15 @@ class ComputeGraph final { utils::GPUMemoryLayout suggested_memory_layout( const std::vector& sizes); + inline bool device_is_adreno() { + return context_->adapter_ptr()->device_type() == vkapi::DeviceType::ADRENO; + } + const std::string& device_name() { + return context()->adapter_ptr()->device_name(); + } + + bool device_name_contains(const char* substr); + // // Graph Building // diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl new file mode 100644 index 00000000000..c8ccbacffc1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.glsl @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} +#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} + +#define TILE_ROWS ${TILE_ROWS} + +#define NGROUPS 8 +#define NWORKERS 8 + +${define_required_extensions(DTYPE)} + +$if WEIGHT_STORAGE == "buffer": + ${define_required_extensions("int8")} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 in_sizes; + ivec4 weight_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS]; + +void main() { + const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; + const uint out_col = gl_GlobalInvocationID.x << 2; + + const int gid = int(gl_LocalInvocationID.x); // group id + const int wid = int(gl_LocalInvocationID.z); // worker id + + if (out_col >= out_sizes.x || out_row >= out_sizes.y) { + return; + } + + VEC4_T a[TILE_ROWS]; + VEC4_T b[4]; + VEC4_T local_c[TILE_ROWS]; + + [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { + local_c[i] = VEC4_T(0.0); + } + + $if SCALES_STORAGE == "buffer": + const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]); + $else: + const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0)); + + for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) { + // Preload t_weight + [[unroll]] for (int i = 0; i < 4; i++) { + $if WEIGHT_STORAGE == "buffer": + b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2]; + $else: + b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0)); + } + // Preload t_in + for (int i = 0; i < TILE_ROWS; i++) { + $if IN_STORAGE == "buffer": + a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2]; + $else: + a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0)); + } + + // Accumulate partial output + [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { + local_c[i] += a[i].x * b[0] + + a[i].y * b[1] + + a[i].z * b[2] + + a[i].w * b[3]; + } + } + + [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { + partial_c[gid][wid][i] = local_c[i]; + } + + memoryBarrierShared(); + barrier(); + + if (wid != 0) { + return; + } + + VEC4_T c[TILE_ROWS]; + + for (int row = 0; row < TILE_ROWS; ++row) { + c[row] = VEC4_T(0.0); + [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) { + c[row] += partial_c[gid][worker][row]; + } + } + + [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) { + $if OUT_STORAGE == "buffer": + if (out_row + i < out_sizes.y) { + t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales; + } + $else: + imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales); + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml new file mode 100644 index 00000000000..5daf28132e6 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_coop.yaml @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q_8w_linear_coop: + parameter_names_with_default_values: + DTYPE: float + IN_STORAGE: texture3d + OUT_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + SCALES_STORAGE: texture2d + TILE_ROWS: 4 + generate_variant_forall: + TILE_ROWS: + - VALUE: 1 + SUFFIX: o4x1 + shader_variants: + - NAME: q_8w_linear_coop_texture3d_texture3d_texture2d_texture2d_float + - NAME: q_8w_linear_coop_buffer_buffer_texture2d_texture2d_float + IN_STORAGE: buffer + OUT_STORAGE: buffer + - NAME: q_8w_linear_coop_buffer_buffer_buffer_buffer_float + IN_STORAGE: buffer + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer + SCALES_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp index 64c2d202529..4a10f469be0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp @@ -142,6 +142,7 @@ void add_q_8w_linear_node( void add_q_8w_linear_tiled_node( ComputeGraph& graph, + const bool use_coop_algorithm, const ValueRef mat1, const ValueRef q_mat2_data, const ValueRef scales_data, @@ -168,7 +169,8 @@ void add_q_8w_linear_tiled_node( ValueRef scales = prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked); - std::string kernel_name = "q_8w_linear_tiled"; + std::string kernel_name = + use_coop_algorithm ? "q_8w_linear_coop" : "q_8w_linear_tiled"; kernel_name.reserve(kShaderNameReserve); add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1)); @@ -197,6 +199,9 @@ void add_q_8w_linear_tiled_node( global_wg_size[1] = global_wg_size[1] / out_tile_nrows; utils::uvec3 local_wg_size{64, 1, 1}; + if (use_coop_algorithm) { + local_wg_size = {8, 1, 8}; + } graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -257,13 +262,26 @@ bool can_use_tiled_impl( return true; } +bool can_use_coop_impl(ComputeGraph& graph, const ValueRef mat1) { + // Do not use coop algorithm for Adreno 702; manual experimentation shows that + // it performs worse than the tiled algorithm. + // TODO(ssjia): Determine a more robust heuristic to determine when the coop + // algorithm should be used, instead of depending on specific device identity. + if (graph.device_is_adreno() && graph.device_name_contains("702")) { + return false; + } + // Check that the computation is vector * matrix + return (graph.size_at(-2, mat1) == 1); +} + void weight_int8pack_mm( ComputeGraph& graph, const std::vector& args) { check_q_8w_linear_args(graph, args[0], args[1], args[2], args[3]); if (can_use_tiled_impl(graph, args[0], args[1], args[2], args[3])) { + bool use_coop_algorithm = can_use_coop_impl(graph, args[0]); return add_q_8w_linear_tiled_node( - graph, args[0], args[1], args[2], args[3]); + graph, use_coop_algorithm, args[0], args[1], args[2], args[3]); } return add_q_8w_linear_node(graph, args[0], args[1], args[2], args[3]); } diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h index d73ed1bc0ce..8ae61095be8 100644 --- a/backends/vulkan/runtime/vk_api/Adapter.h +++ b/backends/vulkan/runtime/vk_api/Adapter.h @@ -122,6 +122,15 @@ class Adapter final { return physical_device_.timestamp_period; } + // Device Identity + inline const std::string& device_name() const { + return physical_device_.device_name; + } + + inline vkapi::DeviceType device_type() const { + return physical_device_.device_type; + } + // Queue Management Queue request_queue(); diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 525f74609a6..4a12f16bbf9 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -152,6 +152,9 @@ def get_linear_inputs(): @register_test_suite("aten._weight_int8pack_mm.default") def get_weight_int8pack_mm_inputs(): MKN_list = [ + [1, 480, 256], + [1, 1024, 1024], + [1, 1024, 256], [3, 480, 256], [6, 480, 256], [6, 256, 1024],