From 539ee1a69f9fb239f9e28dcd6f23623551eccf6a Mon Sep 17 00:00:00 2001 From: Vivek Trivedi <5340687+trivedivivek@users.noreply.github.com> Date: Thu, 3 Apr 2025 21:54:55 -0700 Subject: [PATCH] [ET-VK] Manual sync to fbsource Contains commits squashed from https://github.com/pytorch/executorch/pull/10117 PR stack. Differential Revision: D72866962 Differential Revision: D72862490 Differential Revision: D72581293 Differential Revision: D72430290 --- .../graph/ops/glsl/native_layer_norm.glsl | 234 ++++++++++++------ .../runtime/graph/ops/glsl/permute.glsl | 14 +- .../runtime/graph/ops/glsl/q_4w_linear.glsl | 133 ---------- .../runtime/graph/ops/glsl/q_8w_linear.glsl | 54 ++-- .../graph/ops/impl/NativeLayerNorm.cpp | 15 +- .../graph/ops/impl/QuantizedLinearInt8.cpp | 58 ++--- 6 files changed, 228 insertions(+), 280 deletions(-) delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl index d6c94661ace..cdbb85da4a7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl @@ -43,106 +43,190 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); const lowp int out_packed_dim = unhash_packed_dim(out_layout); -void main() { - const ivec3 lpos = ivec3(gl_GlobalInvocationID); +#define SHARED_MEMORY_FACTOR 2 +#define MAX_WORKGROUP_SIZE 64 + +#define offset_pos_index(index) ((index) + ((index) >> 2)) + +shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)]; + +// function to reduce input data in workgroup's x dimension +void reduce_input(const int width_stride, const int shared_idx_offset) { + // wait for all shared memory writes to finish + memoryBarrierShared(); + barrier(); + + // loop log(width_stride) times + for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1); current_stride < width_stride; current_stride *= 2, index <<= 1) { + // if the index at this thread is within the width stride + if (index < width_stride) { + const int local_shared_idx = shared_idx_offset + index; + // add the value at current stride to this thread's value + shared_input[offset_pos_index(local_shared_idx)] += shared_input[offset_pos_index(local_shared_idx + current_stride)]; + } - if (any(greaterThanEqual(lpos, out_limits))) { - return; + memoryBarrierShared(); + barrier(); } +} +void main() { + const ivec3 lpos = ivec3(gl_GlobalInvocationID); const int width = int(sizes.x); + ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); + + // width batch read stride + const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR; + + // local memory starting offset for this thread + const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y); + + // local memory index for this thread + const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x); + + // if packed dimension width if (in_packed_dim != W_DIM) { VEC4_T mean = VEC4_T(0); - VEC4_T delta = VEC4_T(0); - VEC4_T delta2 = VEC4_T(0); - VEC4_T M2 = VEC4_T(0); - - // Use Welford's online algorithm to compute mean and variance in one pass - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm - ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); - for (int w = 0; w < width; ++w) { - 
in_pos[in_axis_map.x] = w; - VEC4_T v = load_texel(t_in, in_pos); - delta = v - mean; - mean += delta / (w + 1); - delta2 = v - mean; - M2 += delta * delta2; + VEC4_T var = VEC4_T(0); + + // Loop over the width in stride increments + for (int width_offset = 0; width_offset < width; width_offset += width_stride) { + // Read input in shared memory + for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) { + in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x); + + VEC4_T in_val = VEC4_T(0); + if (all(lessThan(in_pos, out_limits))) { + in_val = load_texel(t_in, in_pos); + } + shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val; + } + + reduce_input(width_stride, shared_idx_offset); + mean += shared_input[offset_pos_index(shared_idx_offset)]; + } + + mean /= width; + + // Loop over the width in stride increments + for (int width_offset = 0; width_offset < width; width_offset += width_stride) { + // Read input in shared memory + for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) { + in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x); + + VEC4_T in_val = mean; + if (all(lessThan(in_pos, out_limits))) { + in_val = load_texel(t_in, in_pos); + } + + const VEC4_T delta = in_val - mean; + shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta; + } + + reduce_input(width_stride, shared_idx_offset); + var += shared_input[offset_pos_index(shared_idx_offset)]; } - VEC4_T var = M2 / width; + var /= width; + VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); VEC4_T offset = -rstd * mean; - for (int w = 0; w < width; ++w) { - in_pos[in_axis_map.x] = w; - VEC4_T v = load_texel(t_in, in_pos); - // broadcasting - VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; - VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; - VEC4_T outtex = (v * rstd + offset) * weight + bias; - write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); + VEC4_T v = load_texel(t_in, lpos); + VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx; + VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx; + VEC4_T outtex = (v * rstd + offset) * weight + bias; + if (all(lessThan(lpos, out_limits))) { + write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map); } - write_texel(t_mean, lpos, mean); - write_texel(t_rstd, lpos, rstd); + if (gl_GlobalInvocationID.x == 0) { + write_texel(t_mean, lpos, mean); + write_texel(t_rstd, lpos, rstd); + } } else { - const int packed_width = divup4(width); - + const int last_packed_width_index = divup4(width) - 1; T mean = T(0); - T delta = T(0); - T delta2 = T(0); - T M2 = T(0); - // Use Welford's online algorithm to compute mean and variance in one pass - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm - ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); - T width_counter = T(1); - - const bool has_unaligned_width = (width & 0x3) != 0; - const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width); - - // iterate through texels that are fully packed ie. 
has 4 components - for (int w = 0; w < fully_packed_4_comp_count; ++w) { - in_pos[in_axis_map.x] = w; - VEC4_T v = load_texel(t_in, in_pos); - for (int i=0; i<4; i++) { - delta = v[i] - mean; - mean += delta / width_counter; - delta2 = v[i] - mean; - M2 += delta * delta2; - width_counter++; + T var = T(0); + const int remain = width & 3; + + const int in_pos_x_limit = out_limits[in_axis_map.x]; + + // Loop over the width in stride increments + for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) { + // Read input in shared memory + for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) { + const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x); + in_pos[in_axis_map.x] = in_pos_x; + + VEC4_T in_val = VEC4_T(0); + if (in_pos_x < in_pos_x_limit) { + in_val = load_texel(t_in, in_pos); + } + + if (in_pos_x == last_packed_width_index && remain != 0) { + const int remain_inv = 4 - remain; + in_val.y = mix(in_val.y, T(0), remain_inv > 2); + in_val.z = mix(in_val.z, T(0), remain_inv > 1); + in_val.w = mix(in_val.w, T(0), remain_inv > 0); + } + + shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val; } + + reduce_input(width_stride, shared_idx_offset); + const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)]; + mean += val.x + val.y + val.z + val.w; } - // handle last texel if its not 4 aligned - if (has_unaligned_width) { - in_pos[in_axis_map.x] = fully_packed_4_comp_count; - const int remaining_width = width & 0x3; - - VEC4_T v = load_texel(t_in, in_pos); - for (int i=0; i 2); + in_val.z = mix(in_val.z, mean.x, remain_inv > 1); + in_val.w = mix(in_val.w, mean.x, remain_inv > 0); + } + + const VEC4_T delta = in_val - mean; + const VEC4_T delta2 = delta * delta; + shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2; } + + reduce_input(width_stride, shared_idx_offset); + const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)]; + var += val.x + val.y + val.z + val.w; } - T var = M2 / (width_counter - 1); - T rstd = inversesqrt(var + epsilon); + var /= width; + + T rstd = pow(var + epsilon, T(-0.5)); T offset = -rstd * mean; - for (int w = 0; w < packed_width; ++w) { - in_pos[in_axis_map.x] = w; - VEC4_T v = load_texel(t_in, in_pos); - VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)); - VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)); - VEC4_T outtex = (v * rstd + offset) * weight + bias; - write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); + VEC4_T v = load_texel(t_in, lpos); + VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)); + VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)); + VEC4_T outtex = (v * rstd + offset) * weight + bias; + if (all(lessThan(lpos, out_limits))) { + write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map); } - write_texel(t_mean, lpos, VEC4_T(mean)); - write_texel(t_rstd, lpos, VEC4_T(rstd)); + if (gl_GlobalInvocationID.x == 0) { + write_texel(t_mean, lpos, VEC4_T(mean)); + write_texel(t_rstd, lpos, VEC4_T(rstd)); + } } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index 8a8703becd9..716c42e8ede 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -31,6 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block { layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const 
int packed_dim = C_DIM; +#extension GL_EXT_control_flow_attributes : require + void main() { ivec3 pos = ivec3(gl_GlobalInvocationID); @@ -54,11 +56,16 @@ void main() { in_bchw_pos[out_ndims[2]] = pos.y; in_bchw_pos[out_ndims[3]] = pos.x; - for (int j = 0; j < 4; ++j) { + const int in_packed_dim_size = in_sizes[3 - out_ndims[in_packed_dim_bchw_index]]; + + [[unroll]] for (int j = 0, bchw_index = in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]; j < 4; ++j, ++bchw_index) { // terminate the loop if trying to access input texture out of bounds - if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) { + if (bchw_index >= in_packed_dim_size) { break; } + // go to position in the input, that is mapped to the packed dim in the output + in_bchw_pos[out_ndims[in_packed_dim_bchw_index]] = bchw_index; + ivec3 fetch_pos; fetch_pos.xy = in_bchw_pos.wz; @@ -74,9 +81,6 @@ void main() { // fetch input texel VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos)); outval[j] = inval[in_packed_dim_lane_index]; - - // go to next position in the input, that is mapped to the packed dim in the output - in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++; } pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]); diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl deleted file mode 100644 index 7350af3415c..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${buffer_scalar_type(DTYPE)} -#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} - -${define_required_extensions(DTYPE)} -$if WEIGHT_STORAGE == "buffer": - ${define_required_extensions("uint8")} - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)} - -layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 mat1_sizes; - ivec4 qmat2_sizes; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -layout(constant_id = 3) const int group_size = 64; - -/* - * This shader computes a linear operator between a floating point input matrix - * x and a weights matrix that is quantized to 4 bits. - * - * The (W, H, C) shape of each tensor is: - * - x: (K, M) - * - weights: (N / 2, K) - * - The weights tensor has a data type of `uint8`. Each element in the tensor - * contains 2 4-bit values packed into a uint8. - * - See the pack_int4_linear_weight_transposed_interleave shader to see more - * details on how the weight tensor is stored. - * - qparams: (2, N, number_of_groups) - * - This tensor contains the scales and zeros quantization parameters for the - * weights tensor. The weight tensor is quantized group-wise, which means - * that every `group_size` elements along the K dimension of the weights - * tensor has independent quantization parameters. Along the width dim, the - * first value contains the scale for the group and the second value - * contains the zero point for the group. 
- * - * Note that this shader assumes that all tensors are width packed. - */ -void main() { - const uint out_row = gl_GlobalInvocationID.y; - // Each thread writes out 2 texels along the width axis, equivalent to 8 - // scalar elements. Therefore multiply the thread_idx.x by 8. - const uint out_col = gl_GlobalInvocationID.x << 3; - // Similar reasoning to the above, each thread works on 2 texels along the - // width axis so multiply thread_idx.x by 2. - const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; - - if (out_col >= out_sizes.x || out_row >= out_sizes.y) { - return; - } - - const int num_blocks = mat1_sizes.x / group_size; - - VEC4_T sums[2]; - - sums[0] = VEC4_T(0); - sums[1] = VEC4_T(0); - - VEC4_T scales[2]; - VEC4_T zeros[2]; - - $if WEIGHT_STORAGE == "buffer": - const int qmat2_stride = qmat2_sizes.x >> 2; - $if PARAMS_STORAGE == "buffer": - const int qparams_y_stride = out_sizes.x >> 2; - const int qparams_z_stride = qparams_y_stride * 2; - - for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { - $if PARAMS_STORAGE == "buffer": - scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx]; - zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride]; - - scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1]; - zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride]; - $else: - scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0); - zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0); - - scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0); - zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0); - - for (int g_idx = 0; g_idx < group_size; g_idx += 4) { - const int k = block_idx * group_size + g_idx; - - $if IN_STORAGE == "buffer": - const VEC4_T mat1_tex = t_mat1[(out_row * mat1_sizes.x + k) >> 2]; - $else: - const VEC4_T mat1_tex = texelFetch(t_mat1, ivec3(k >> 2, out_row, 0), 0); - - for (int comp = 0; comp < 4; ++comp) { - $if WEIGHT_STORAGE == "buffer": - const u8vec4 packed_weight_tex = t_qmat2[(k + comp) * qmat2_stride + gl_GlobalInvocationID.x]; - $else: - const uvec4 packed_weight_tex = texelFetch( - t_qmat2, - ivec2(gl_GlobalInvocationID.x, k + comp), - 0); - - const uvec4 weight_tex_1 = (packed_weight_tex & 0xF0) >> 4; - const uvec4 weight_tex_2 = packed_weight_tex & 0x0F; - - sums[0] += mat1_tex[comp] * ((vec4(weight_tex_1) - 8.0) * scales[0] + zeros[0]); - sums[1] += mat1_tex[comp] * ((vec4(weight_tex_2) - 8.0) * scales[1] + zeros[1]); - } - } - } - - $if OUT_STORAGE == "buffer": - t_out[(out_row * out_sizes.x + out_col) >> 2] = sums[0]; - t_out[(out_row * out_sizes.x + out_col + 4) >> 2] = sums[1]; - $else: - imageStore(t_out, ivec3(out_col_texel_idx, out_row, 0), sums[0]); - imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row, 0), sums[1]); -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 228e2e8f870..dfb5f1f2f9c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -29,16 +29,20 @@ ${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)} ${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)} $if STORAGE == "buffer": - ${layout_declare_ubo(4, "ivec4", "out_sizes")} - ${layout_declare_ubo(5, "ivec4", "out_strides")} - ${layout_declare_ubo(6, "int", "out_numel")} 
- ${layout_declare_ubo(7, "ivec4", "mat1_sizes")} - ${layout_declare_ubo(8, "ivec4", "mat1_strides")} - ${layout_declare_ubo(9, "ivec4", "qmat2_strides")} - ${layout_declare_ubo(10, "ivec4", "scales_strides")} + layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 out_strides; + ivec4 mat1_sizes; + ivec4 mat1_strides; + ivec4 qmat2_strides; + ivec4 scales_strides; + int out_numel; + }; $else: - ${layout_declare_ubo(4, "ivec3", "out_limits")} - ${layout_declare_ubo(5, "ivec4", "mat1_sizes")} + layout(push_constant) uniform restrict Block { + ivec3 out_limits; + ivec4 mat1_sizes; + }; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; @@ -83,42 +87,40 @@ void main() { #else // USING_TEXTURE -#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require - void main() { - const u16vec2 out_pos = u16vec2( - gl_GlobalInvocationID.x, - gl_GlobalInvocationID.y); + const ivec2 out_pos = ivec2( + gl_GlobalInvocationID.x % out_limits.x, + gl_GlobalInvocationID.x / out_limits.x); - if (out_pos.x >= out_limits.x || out_pos.y >= out_limits.y) { + if (out_pos.y >= out_limits.y) { return; } - const uint16_t qmat2_pos_x = out_pos.x; + const int qmat2_pos_x = out_pos.x; VEC4_T outtex = VEC4_T(0); - const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0)); + const VEC4_T scales = load_texel(t_scales, ivec3(out_pos.x, 0, 0)); VEC4_T mat1_tex; VEC4_T mat2_tex[4]; for ( - uint16_t i = uint16_t(0), x = uint16_t(0); - i < uint16_t(mat1_sizes.x); - i += uint16_t(4), x++) + int i = 0, x = 0; + i < mat1_sizes.x; + i += 4, x++) { - mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0)); + mat1_tex = load_texel(t_mat1, ivec3(x, out_pos.y, 0)); - mat2_tex[0] = load_texel(t_qmat2, u16vec3(out_pos.x, i, 0)); - mat2_tex[1] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(1), 0)); - mat2_tex[2] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(2), 0)); - mat2_tex[3] = load_texel(t_qmat2, u16vec3(out_pos.x, i + uint16_t(3), 0)); + mat2_tex[0] = load_texel(t_qmat2, ivec3(out_pos.x, i, 0)); + mat2_tex[1] = load_texel(t_qmat2, ivec3(out_pos.x, i + 1, 0)); + mat2_tex[2] = load_texel(t_qmat2, ivec3(out_pos.x, i + 2, 0)); + mat2_tex[3] = load_texel(t_qmat2, ivec3(out_pos.x, i + 3, 0)); outtex += mat1_tex.x * mat2_tex[0] + mat1_tex.y * mat2_tex[1] + mat1_tex.z * mat2_tex[2] + mat1_tex.w * mat2_tex[3]; } outtex *= scales; - write_texel(t_out, u16vec3(out_pos, 0), outtex); + write_texel(t_out, ivec3(out_pos, 0), outtex); } #endif diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 7aa98e52654..f2e8eff763a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -83,8 +83,19 @@ void add_native_layer_norm_node( std::vector in_sizes = t_input->sizes(); - utils::uvec3 global_size = t_mean->logical_limits(); - utils::uvec3 local_size = adaptive_work_group_size(global_size); + utils::uvec3 global_size = t_out->logical_limits(); + utils::uvec3 local_size; + + // Since the shader sets shared memory scale factor > 1, if dispatch is + // greater than maximum WG size. Setting WG size in X axis to max WG size, + // would allow best thread utilization. + if (global_size[0] > 64) { + local_size = {64, 1, 1}; + } else { + // If thread size in X axis is smaller or equal to maximum WG size, we can + // let the function decide the best WG size. 
+ local_size = graph.create_local_wg_size(global_size); + } std::string kernel_name("native_layer_norm"); kernel_name.reserve(kShaderNameReserve); diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp index 49085ff4e06..5054b2e5e9c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearInt8.cpp @@ -98,47 +98,25 @@ void add_q_8w_linear_node( add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed)); add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed)); - vkapi::ParamsBindList ubos({}); + std::vector pcs; if (graph.is_buffer_storage(out_W_packed)) { - ubos.append( - {graph.sizes_ubo(out_W_packed), - graph.strides_ubo(out_W_packed), - graph.numel_ubo(out_W_packed), - graph.sizes_ubo(mat1_W_packed), - graph.strides_ubo(mat1), - graph.strides_ubo(q_mat2), - graph.strides_ubo(scales)}); + pcs = { + graph.sizes_pc_of(out_W_packed), + graph.strides_pc_of(out_W_packed), + graph.sizes_pc_of(mat1_W_packed), + graph.strides_pc_of(mat1), + graph.strides_pc_of(q_mat2), + graph.strides_pc_of(scales), + graph.numel_pc_of(out_W_packed)}; } else { - ubos.append( - {graph.logical_limits_ubo(out_W_packed), - graph.sizes_ubo(mat1_W_packed)}); + pcs = { + graph.logical_limits_pc_of(out_W_packed), + graph.sizes_pc_of(mat1_W_packed)}; } - utils::uvec3 global_wg; - if (graph.is_buffer_storage(out)) { - global_wg = {static_cast(graph.numel_of(out_W_packed)), 1, 1}; - } else { - global_wg = graph.logical_limits_of(out_W_packed); - } - - utils::uvec3 local_wg{8, 8, 1}; - int32_t out_W = graph.size_at(-1, out_W_packed); - - if (graph.is_buffer_storage(out_W_packed)) { - local_wg[0] = 64; - local_wg[1] = 1; - local_wg[2] = 1; - } else { - if (out_W % 8 != 0) { - if (out_W % 4 == 0) { - local_wg[0] = 4; - local_wg[1] = 16; - } else { - local_wg[0] = 2; - local_wg[1] = 32; - } - } - } + const utils::uvec3 global_wg = { + static_cast(graph.numel_of(out_W_packed)), 1, 1}; + const utils::uvec3 local_wg{64, 1, 1}; graph.execute_nodes().emplace_back(new DispatchNode( graph, @@ -149,11 +127,13 @@ void add_q_8w_linear_node( {{out_W_packed, vkapi::MemoryAccessType::WRITE}, {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, // Shader params buffers - ubos, + {}, // Specialization Constants {}, // Resizing Logic - resize_q_8w_linear_node)); + resize_q_8w_linear_node, + {}, + pcs)); if (!graph.is_buffer_storage(out) && graph.packed_dim_of(out) != WHCN::kWidthDim) { viewFn(graph, {out_W_packed, graph.add_none(), out});
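
Note on the native_layer_norm change: the new shader replaces the per-thread Welford single-pass mean/variance loop with a two-pass computation backed by a workgroup shared-memory tree reduction (reduce_input), so up to MAX_WORKGROUP_SIZE threads cooperate along the width dimension instead of one thread walking the whole row. As a rough host-side reference of the math each workgroup now performs, here is a minimal C++ sketch (an illustration only, not code from the patch): it models the doubling-stride tree reduction and the two passes on a single row of scalar floats, and omits the VEC4 vectorization, the SHARED_MEMORY_FACTOR batching, the offset_pos_index bank-conflict padding, and the packed-dim tail masking that the shader handles.

#include <cmath>
#include <cstddef>
#include <vector>

struct LayerNormStats { float mean; float rstd; };

// Sum a row with a doubling-stride tree reduction, mirroring reduce_input:
// at each stride, element i accumulates element i + stride
// (like shared_input[idx] += shared_input[idx + current_stride]).
static float tree_reduce(std::vector<float> vals) {
  std::size_t n = 1;
  while (n < vals.size()) n <<= 1;
  vals.resize(n, 0.0f);  // pad with zeros so strides line up, as the shader does for out-of-bounds reads
  for (std::size_t stride = 1; stride < n; stride <<= 1) {
    for (std::size_t i = 0; i + stride < n; i += 2 * stride) {
      vals[i] += vals[i + stride];
    }
  }
  return vals[0];
}

// Pass 1: reduce the inputs to get the mean.
// Pass 2: reduce the squared deltas to get the variance, then rstd = 1/sqrt(var + eps).
LayerNormStats layer_norm_stats(const std::vector<float>& row, float epsilon) {
  const float width = static_cast<float>(row.size());
  const float mean = tree_reduce(row) / width;
  std::vector<float> sq_deltas(row.size());
  for (std::size_t i = 0; i < row.size(); ++i) {
    const float d = row[i] - mean;
    sq_deltas[i] = d * d;
  }
  const float var = tree_reduce(sq_deltas) / width;
  return {mean, 1.0f / std::sqrt(var + epsilon)};
}

Relative to the old Welford loop, this trades a second read of the input row for a reduction that parallelizes across the workgroup's X dimension, which is also why NativeLayerNorm.cpp now pins the X local size to 64 whenever the global X extent exceeds it.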
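
Note on q_8w_linear: the texture path and its dispatch in QuantizedLinearInt8.cpp move from a 2D workgroup over (out_limits.x, out_limits.y) to a linearized 1D dispatch with a fixed {64, 1, 1} local size, and each thread recovers its 2D output texel from the flat invocation id. A minimal sketch of that mapping, assuming out_limits_x is the output's texel extent along X (names mirror the shader, not an API from the patch):

#include <cstdint>

struct OutPos { int32_t x; int32_t y; };

// Recover the 2D texel position from a flat invocation index, as the updated
// q_8w_linear texture shader does with gl_GlobalInvocationID.x.
OutPos out_pos_from_flat(int32_t flat_id, int32_t out_limits_x) {
  return {flat_id % out_limits_x, flat_id / out_limits_x};
}

// e.g. flat_id = 13 with out_limits_x = 5 maps to texel (3, 2);
// threads whose recovered y >= out_limits.y return early in the shader.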