diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 704601a0271..f2e8eff763a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -84,7 +84,18 @@ void add_native_layer_norm_node( std::vector in_sizes = t_input->sizes(); utils::uvec3 global_size = t_out->logical_limits(); - utils::uvec3 local_size = graph.create_local_wg_size(global_size); + utils::uvec3 local_size; + + // Since the shader sets the shared memory scale factor > 1 if the dispatch is + // greater than the maximum WG size, setting the WG size in the X axis to the + // max WG size allows the best thread utilization. + if (global_size[0] > 64) { + local_size = {64, 1, 1}; + } else { + // If the thread size in the X axis is smaller than or equal to the maximum + // WG size, we can let the function decide the best WG size. + local_size = graph.create_local_wg_size(global_size); + } std::string kernel_name("native_layer_norm"); kernel_name.reserve(kShaderNameReserve);