diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
index c3bd9f41af9..8a8670b4bb3 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.glsl
@@ -17,17 +17,17 @@
 
 ${define_required_extensions(DTYPE)}
 
-$if STORAGE == "buffer":
+$if WEIGHT_STORAGE == "buffer":
   ${define_required_extensions("int8")}
 
 #extension GL_EXT_control_flow_attributes : require
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_weight", "int8", STORAGE, is_scalar_array=False)}
-${layout_declare_tensor(B, "r", "t_scales", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}
 
 layout(push_constant) uniform restrict Block {
@@ -50,10 +50,10 @@ void main() {
   VEC4_T b[4];
   VEC4_T c[TILE_ROWS];
 
-  $if STORAGE == "buffer":
+  $if SCALES_STORAGE == "buffer":
     const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
   $else:
-    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec3(out_col >> 2, 0, 0), 0));
+    const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));
 
   [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
     c[i] = VEC4_T(0.0);
@@ -62,30 +62,32 @@ void main() {
   for (int pos = 0; pos < in_sizes.x; pos += 4) {
     // Preload weight tensor
     [[unroll]] for (int i = 0; i < 4; i++) {
-      $if STORAGE == "buffer":
-        b[i] = t_weight[((pos + i) * B_sizes.x + out_col) >> 2];
+      $if WEIGHT_STORAGE == "buffer":
+        b[i] = t_weight[((pos + i) * out_sizes.x + out_col) >> 2];
       $else:
-        b[i] = VEC4_T(texelFetch(t_weight, ivec3(out_col >> 2, pos + i, 0), 0));
+        b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
     }
 
     // Preload input tensor
     [[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
-      $if STORAGE == "buffer":
-        a[i] = t_in[((out_row + i) * in_sizes.x + (pos)) >> 2];
+      $if IN_STORAGE == "buffer":
+        a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
       $else:
         a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
     }
 
-    // Compute partial output
+    // Accumulate output
     [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
      c[i] += a[i].x * b[0] + a[i].y * b[1] + a[i].z * b[2] + a[i].w * b[3];
    }
  }
 
-  // Store output tensor
+  // Store to output tensor
   [[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
-    $if STORAGE == "buffer":
-      t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
+    $if OUT_STORAGE == "buffer":
+      if (out_row + i < out_sizes.y) {
+        t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
+      }
     $else:
       imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
   }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
index b01af47e179..1e8a5e1fe7d 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_tiled.yaml
@@ -7,12 +7,26 @@
 q_8w_linear_tiled:
   parameter_names_with_default_values:
     DTYPE: float
-    STORAGE: texture3d
+    IN_STORAGE: texture3d
+    OUT_STORAGE: texture3d
+    WEIGHT_STORAGE: texture2d
+    SCALES_STORAGE: texture2d
     TILE_ROWS: 4
+  generate_variant_forall:
+    TILE_ROWS:
+      - VALUE: 1
+        SUFFIX: o4x1
+      - VALUE: 4
+        SUFFIX: o4x4
+      - VALUE: 6
+        SUFFIX: o4x6
   shader_variants:
-    - NAME: q_8w_linear_tiled_o4x4_texture3d_float
-      STORAGE: texture3d
-      TILE_ROWS: 4
-    - NAME: q_8w_linear_tiled_o4x6_texture3d_float
-      STORAGE: texture3d
-      TILE_ROWS: 6
+    - NAME: q_8w_linear_tiled_texture3d_texture3d_texture2d_texture2d_float
+    - NAME: q_8w_linear_tiled_buffer_buffer_texture2d_texture2d_float
+      IN_STORAGE: buffer
+      OUT_STORAGE: buffer
+    - NAME: q_8w_linear_tiled_buffer_buffer_buffer_buffer_float
+      IN_STORAGE: buffer
+      OUT_STORAGE: buffer
+      WEIGHT_STORAGE: buffer
+      SCALES_STORAGE: buffer
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
index 5054b2e5e9c..64c2d202529 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp
@@ -146,28 +146,53 @@ void add_q_8w_linear_tiled_node(
     const ValueRef q_mat2_data,
     const ValueRef scales_data,
     const ValueRef out) {
-  utils::StorageType stype = graph.storage_type_of(out);
+  utils::StorageType q_mat2_storage = utils::kTexture2D;
+
+  uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim();
+  std::vector<int64_t> qmat2_orig_sizes = graph.sizes_of(q_mat2_data);
+  const int64_t ndim = graph.dim_of(q_mat2_data);
+  const int64_t K = qmat2_orig_sizes.at(ndim - 1);
+  const int64_t N = qmat2_orig_sizes.at(ndim - 2);
+
+  if (N > max_extent * 4 || K > max_extent) {
+    q_mat2_storage = utils::kBuffer;
+  }
+
   ValueRef q_mat2 = prepack_standard_hw_transposed(
-      graph, q_mat2_data, stype, utils::kWidthPacked);
+      graph, q_mat2_data, q_mat2_storage, utils::kWidthPacked);
+
+  utils::StorageType scales_storage = utils::kTexture2D;
+  if (N > max_extent) {
+    scales_storage = utils::kBuffer;
+  }
   ValueRef scales =
-      prepack_standard(graph, scales_data, stype, utils::kWidthPacked);
+      prepack_standard(graph, scales_data, scales_storage, utils::kWidthPacked);
 
   std::string kernel_name = "q_8w_linear_tiled";
   kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(q_mat2));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(scales));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
   std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
   const int64_t M = utils::val_at(-2, mat1_sizes);
 
   int out_tile_nrows = 4;
   if (M % 6 == 0) {
     kernel_name += "_o4x6";
     out_tile_nrows = 6;
+  } else if (M % 4 == 0) {
+    kernel_name += "_o4x4";
+    out_tile_nrows = 4;
+  } else if (M % 1 == 0) {
+    kernel_name += "_o4x1";
+    out_tile_nrows = 1;
   } else {
     kernel_name += "_o4x4";
     out_tile_nrows = 4;
   }
 
-  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
-  add_dtype_suffix(kernel_name, graph.dtype_of(out));
-
   utils::uvec3 global_wg_size = graph.logical_limits_of(out);
   global_wg_size[1] = global_wg_size[1] / out_tile_nrows;
@@ -209,18 +234,13 @@ bool can_use_tiled_impl(
   if (graph.size_at(-1, mat1) % 4 != 0) {
     return false;
   }
-  // Check that M is a multiple of 4 or 6
-  if (graph.size_at(-2, mat1) % 4 != 0 &&
-      graph.size_at(-2, mat1) % 6 != 0) {
-    return false;
-  }
-  // Check that the storage type is texture
-  // TODO(ssjia): Add support for buffer storage in the tiled impl
-  if (graph.storage_type_of(out) != utils::kTexture3D) {
+  // Check that N is a multiple of 4
+  if (graph.size_at(-1, out) % 4 != 0) {
     return false;
   }
   // Check that the packed dim is the width dim
-  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+  if (graph.packed_dim_of(mat1) != WHCN::kWidthDim &&
+      graph.packed_dim_of(out) != WHCN::kWidthDim) {
     return false;
   }
   // Check that no special axis mapping is used for the input
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index f97b2c51370..525f74609a6 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -152,6 +152,7 @@ def get_linear_inputs():
 @register_test_suite("aten._weight_int8pack_mm.default")
 def get_weight_int8pack_mm_inputs():
     MKN_list = [
+        [3, 480, 256],
         [6, 480, 256],
         [6, 256, 1024],
         [6, 1024, 256],