From 83d0bbfc389d45fd50cb2f1e2cf81f5466935803 Mon Sep 17 00:00:00 2001 From: Francois Date: Thu, 30 Jan 2020 01:25:51 +0100 Subject: [PATCH 01/21] Adds spectral functions --- examples/fft_benchmark.py | 149 +++++++++++++++++ mesh_tensorflow/ops.py | 19 ++- .../ops_with_redefined_builtins.py | 2 +- mesh_tensorflow/signal_ops.py | 158 ++++++++++++++++++ 4 files changed, 325 insertions(+), 3 deletions(-) create mode 100644 examples/fft_benchmark.py create mode 100644 mesh_tensorflow/signal_ops.py diff --git a/examples/fft_benchmark.py b/examples/fft_benchmark.py new file mode 100644 index 00000000..464ab933 --- /dev/null +++ b/examples/fft_benchmark.py @@ -0,0 +1,149 @@ +""" +Benchmark script for studying the scaling of distributed FFTs on Mesh Tensorflow +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +import tensorflow.compat.v1 as tf +import mesh_tensorflow as mtf + +from tensorflow.python.tpu import tpu_config # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.tpu import tpu_estimator # pylint: disable=g-direct-tensorflow-import +from tensorflow_estimator.python.estimator import estimator as estimator_lib + +# Cloud TPU Cluster Resolver flags +tf.flags.DEFINE_string( + "tpu", default=None, + help="The Cloud TPU to use for training. This should be either the name " + "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " + "url.") +tf.flags.DEFINE_string( + "tpu_zone", default=None, + help="[Optional] GCE zone where the Cloud TPU is located in. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") +tf.flags.DEFINE_string( + "gcp_project", default=None, + help="[Optional] Project name for the Cloud TPU-enabled project. If not " + "specified, we will attempt to automatically detect the GCE project from " + "metadata.") + +tf.flags.DEFINE_string("model_dir", None, "Estimator model_dir") + +tf.flags.DEFINE_integer("cube_size", 512, "Size of the 3D volume.") +tf.flags.DEFINE_integer("batch_size", 128, + "Mini-batch size for the training. Note that this " + "is the global batch size and not the per-shard batch.") + +tf.flags.DEFINE_string("mesh_shape", "b1:32", "mesh shape") +tf.flags.DEFINE_string("layout", "nx:b1,tny:b1", "layout rules") + +FLAGS = tf.flags.FLAGS + +def benchmark_model(mesh): + """ + Initializes a 3D volume with random noise, and execute a forward FFT + """ + batch_dim = mtf.Dimension("batch", FLAGS.batch_size) + + # Declares real space dimensions + x_dim = mtf.Dimension("nx", FLAGS.cube_size) + y_dim = mtf.Dimension("ny", FLAGS.cube_size) + z_dim = mtf.Dimension("nz", FLAGS.cube_size) + + # Declares Fourier space dimensions + tx_dim = mtf.Dimension("tnx", FLAGS.cube_size) + ty_dim = mtf.Dimension("tny", FLAGS.cube_size) + tz_dim = mtf.Dimension("tnz", FLAGS.cube_size) + + # Create field + field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim]) + + # Apply FFT + fft_field = mtf.signal.fft3d(mtf.cast(field, tf.complex64), [tx_dim, ty_dim, tz_dim]) + + # Inverse FFT + rfield = mtf.cast(mtf.signal.ifft3d(fft_field, [x_dim, y_dim, z_dim]), tf.float32) + + # Compute errors + err = mtf.reduce_max(mtf.abs(field - rfield)) + return err + +def model_fn(features, labels, mode, params): + """A model is called by TpuEstimator.""" + del labels + del features + + mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape) + layout_rules = mtf.convert_to_layout_rules(FLAGS.layout) + + ctx = params['context'] + num_hosts = ctx.num_hosts + host_placement_fn = ctx.tpu_host_placement_function + device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)] + tf.logging.info('device_list = %s' % device_list,) + + mesh_devices = [''] * mesh_shape.size + mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl( + mesh_shape, layout_rules, mesh_devices, ctx.device_assignment) + + graph = mtf.Graph() + mesh = mtf.Mesh(graph, "fft_mesh") + + with mtf.utils.outside_all_rewrites(): + err = benchmark_model(mesh) + + lowering = mtf.Lowering(graph, {mesh: mesh_impl}) + + tf_err = tf.to_float(lowering.export_to_tf_tensor(err)) + + with mtf.utils.outside_all_rewrites(): + return tpu_estimator.TPUEstimatorSpec(mode, loss=tf_err) + + +def main(_): + + tf.logging.set_verbosity(tf.logging.INFO) + mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape) + + # Resolve the TPU environment + tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + FLAGS.tpu, + zone=FLAGS.tpu_zone, + project=FLAGS.gcp_project + ) + + run_config = tf.estimator.tpu.RunConfig( + cluster=tpu_cluster_resolver, + model_dir=FLAGS.model_dir, + save_checkpoints_steps=None, # Disable the default saver + save_checkpoints_secs=None, # Disable the default saver + log_step_count_steps=100, + save_summary_steps=100, + tpu_config=tpu_config.TPUConfig( + num_shards=mesh_shape.size, + iterations_per_loop=100, + num_cores_per_replica=1, + per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST)) + + model = tpu_estimator.TPUEstimator( + use_tpu=True, + model_fn=model_fn, + config=run_config, + train_batch_size=FLAGS.batch_size, + eval_batch_size=FLAGS.batch_size) + + def dummy_input_fn(params): + """Dummy input function """ + return tf.zeros(shape=[params['batch_size']], dtype=tf.float32), tf.zeros(shape=[params['batch_size']], dtype=tf.float32) + + # Run evaluate loop for ever, we will be connecting to this process using a profiler + model.evaluate(input_fn=dummy_input_fn, steps=100000) + +if __name__ == "__main__": + tf.disable_v2_behavior() + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run() diff --git a/mesh_tensorflow/ops.py b/mesh_tensorflow/ops.py index a0783fb0..25dda4e2 100644 --- a/mesh_tensorflow/ops.py +++ b/mesh_tensorflow/ops.py @@ -1472,8 +1472,8 @@ def to_string(self): @property def has_gradient(self): return ( - [t for t in self.inputs if t.dtype.is_floating] and - [t for t in self.outputs if t.dtype.is_floating]) + [t for t in self.inputs if t.dtype.is_floating or t.dtype.is_complex] and + [t for t in self.outputs if t.dtype.is_floating or t.dtype.is_complex]) def gradient(self, unused_grad_ys): raise NotImplementedError("Gradient not implemented") @@ -5583,6 +5583,21 @@ def random_uniform(mesh, shape, **kwargs): return RandomOperation(mesh, shape, tf.random.uniform, **kwargs).outputs[0] +def random_normal(mesh, shape, **kwargs): + """Random normal. + + Args: + mesh: a Mesh + shape: a Shape + **kwargs: keyword args for tf.random.normal, except seed + + Returns: + a Tensor + """ + shape = mtf.convert_to_shape(shape) + return mtf.RandomOperation(mesh, shape, tf.random.normal, **kwargs).outputs[0] + + def dropout(x, keep_prob=None, rate=None, noise_shape=None, name=None): """Randomly set some elements to 0 and scale up the rest. diff --git a/mesh_tensorflow/ops_with_redefined_builtins.py b/mesh_tensorflow/ops_with_redefined_builtins.py index 50848c6e..9b5ef108 100644 --- a/mesh_tensorflow/ops_with_redefined_builtins.py +++ b/mesh_tensorflow/ops_with_redefined_builtins.py @@ -24,7 +24,7 @@ from mesh_tensorflow.ops import mtf_pow as pow # pylint: disable=redefined-builtin,unused-import from mesh_tensorflow.ops import mtf_range as range # pylint: disable=redefined-builtin,unused-import from mesh_tensorflow.ops import mtf_slice as slice # pylint: disable=redefined-builtin,unused-import - +import mesh_tensorflow.signal_ops as signal # TODO(trandustin): Seal module. diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py new file mode 100644 index 00000000..011f0969 --- /dev/null +++ b/mesh_tensorflow/signal_ops.py @@ -0,0 +1,158 @@ +"""Spectral ops for Mesh TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import functools +import itertools +import operator +import os +import re + +from mesh_tensorflow import utils +import numpy as np +import six +from six.moves import xrange # pylint: disable=redefined-builtin + +import tensorflow.compat.v1 as tf + +from mesh_tensorflow import ops_with_redefined_builtins as mtf + +class FFT3DOperation(mtf.Operation): + """ + Computes the 3-dimensional discrete Fourier transform over the inner-most 3 + dimensions of input tensor. Note that the output FFT is transposed. + + Args: + input: A Tensor. Must be one of the following types: complex64, complex128 + freq_dims: List of 3 Dimensions representing the frequency dimensions. + name: A name for the operation (optional). + + Returns: + A Tensor of shape `input.shape[:-3] + freq_dims`. + """ + def __init__(self, input, freq_dims, name=None): + super(FFT3DOperation, self).__init__([input], name=name or "FFT3D") + self._freq_dims = freq_dims + self._output_shape = mtf.Shape(input.shape[:-3]+[freq_dims[1], freq_dims[2], freq_dims[0]]) + self._outputs = [mtf.Tensor(self, mtf.Shape(self._output_shape), input.dtype)] + + def gradient(self, grad_ys): + dy = grad_ys[0] + x = self.inputs[0] + return [ifft3d(dy, x.shape[-3:])] + + def lower(self, lowering): + mesh_impl = lowering.mesh_impl(self) + x = self.inputs[0] + naxes = len(x.shape) + slices = lowering.tensors[self.inputs[0]] + # Before performing any operations, we check the splitting + split_axes = [] + for i in range(3): + split_axes.append(mesh_impl.tensor_dimension_to_mesh_axis(x.shape.dims[-3:][i])) + + # Perform FFT followed by tranposes + for i in range(2): + # Apply FFT along last axis + slices = mesh_impl.slicewise(tf.spectral.fft, slices) + + # Before transposing the array, making sure the new last dimension will + # be contiguous + if split_axes[-2] is not None: + slices = mesh_impl.alltoall(slices, split_axes[-2], naxes-1, naxes-2) + split_axes[-1] = split_axes[-2] + split_axes[-2] = None + perm = np.arange(len(x.shape)) + perm[-3:] = np.roll(perm[-3:], shift=1) + slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) + split_axes = [split_axes[2], split_axes[0], split_axes[1]] + + # Apply FFT along last axis + slices = mesh_impl.slicewise(tf.spectral.fft, slices) + lowering.set_tensor_lowering(self.outputs[0], slices) + +def fft3d(x, freq_dims, name=None): + """ + Computes the 3-dimensional discrete Fourier transform over the inner-most 3 + dimensions of input tensor. Note that the output FFT is transposed. + + Args: + input: A Tensor. Must be one of the following types: complex64, complex128 + freq_dims: List of 3 Dimensions representing the frequency dimensions. + name: A name for the operation (optional). + + Returns: + A Tensor of shape `input.shape[:-3] + freq_dims`. + """ + return FFT3DOperation(x, freq_dims, name).outputs[0] + +class iFFT3DOperation(mtf.Operation): + """ + Computes the inverse 3-dimensional discrete Fourier transform over the inner-most 3 + dimensions of input tensor. Note that the input FFT is assumed transposed. + + Args: + input: A Tensor. Must be one of the following types: complex64, complex128 + dims: List of 3 Dimensions representing the direct space dimensions. + name: A name for the operation (optional). + + Returns: + A Tensor of shape `input.shape[:-3] + dims`. + """ + def __init__(self, input, dims, name=None): + super(iFFT3DOperation, self).__init__([input], name=name or "iFFT3D") + self._dims = dims + self._output_shape = mtf.Shape(input.shape[:-3]+dims) + self._outputs = [mtf.Tensor(self, mtf.Shape(self._output_shape), input.dtype)] + + def gradient(self, grad_ys): + dy = grad_ys[0] + ky, kz, kx = self.inputs[0].shape[-3:] + return [fft3d(dy, [kx, ky, kz])] + + def lower(self, lowering): + mesh_impl = lowering.mesh_impl(self) + x = self.inputs[0] + naxes = len(x.shape) + slices = lowering.tensors[self.inputs[0]] + # Before performing any operations, we check the splitting + split_axes = [] + for i in range(3): + split_axes.append(mesh_impl.tensor_dimension_to_mesh_axis(x.shape.dims[-3:][i])) + + # Perform FFT followed by tranposes + for i in range(2): + # Apply FFT along last axis + slices = mesh_impl.slicewise(tf.spectral.ifft, slices) + + # Before transposing the array, making sure the new last dimension will + # be contiguous + if split_axes[0] is not None: + slices = mesh_impl.alltoall(slices, split_axes[0], naxes-1, naxes-3) + split_axes[-1] = split_axes[0] + split_axes[0] = None + perm = np.arange(len(x.shape)) + perm[-3:] = np.roll(perm[-3:], shift=-1) + slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) + split_axes = [split_axes[1], split_axes[2], split_axes[0]] + + # Apply FFT along last axis + slices = mesh_impl.slicewise(tf.spectral.ifft, slices) + lowering.set_tensor_lowering(self.outputs[0], slices) + +def ifft3d(x, dims, name=None): + """ + Computes the inverse 3-dimensional discrete Fourier transform over the inner-most 3 + dimensions of input tensor. Note that the input FFT is assumed transposed. + + Args: + input: A Tensor. Must be one of the following types: complex64, complex128 + dims: List of 3 Dimensions representing the direct space dimensions. + name: A name for the operation (optional). + + Returns: + A Tensor of shape `input.shape[:-3] + dims`. + """ + return iFFT3DOperation(x, dims, name).outputs[0] From 86d9991cb36b7bae70020ab1664a47a706cacbf4 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 15:54:27 +0100 Subject: [PATCH 02/21] added complex manipulation ops --- mesh_tensorflow/ops.py | 71 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/mesh_tensorflow/ops.py b/mesh_tensorflow/ops.py index 03535e52..1df39936 100644 --- a/mesh_tensorflow/ops.py +++ b/mesh_tensorflow/ops.py @@ -6600,3 +6600,74 @@ def nth_largest_element(x, n, reduced_dim, name=None): def nth_smallest_element(x, n, reduced_dim, name=None): return -nth_largest_element(-x, n, reduced_dim, name=name) + +def to_complex(x, complex_dim=None): + """Gathers the real and imaginary of a tensor in a complex tensor + + Args: + x: a float Tensor + complex_dim: a Dimension where both the real and imaginary parts of the + tensor are. Defaults to None, which corresponds to the last + dimension of the tensor. + Returns: + a Tensor, complex-valued + """ + if complex_dim is None: + complex_dim = x.shape[-1] + x_real, x_imag = split(x, complex_dim, 2) + x_real = cast(x_real, tf.complex64) + x_imag = cast(x_imag, tf.complex64) + x_complex = x_real + 1j * x_imag + return x_complex + +def split_complex(x, complex_dim=None): + """Splits a complex tensor into real and imaginary, concatenated + + Args: + x: a float Tensor + complex_dim: a Dimension where you want the split to happen. + Defaults to None, which corresponds to the last dimension of the tensor. + Returns: + a Tensor, float-valued + """ + op = SplitComplexOperation(x, complex_dim=complex_dim) + output = op.outputs[0] + return output + +class SplitComplexOperation(Operation): + def __init__(self, split_input, complex_dim=None, name=None): + super().__init__([split_input], name=name or 'split_complex') + if complex_dim is None: + self._split_dim = split_input.shape.dims[-1] + self._split_axis = -1 + else: + self._split_dim = complex_dim + self._split_axis = split_input.shape.index(complex_dim) + self._splittable_dims, self._unsplittable_dims = ( + self._initialize_splittable_and_unsplittable_dims( + "splittable", [self._split_dim.name], + ) + ) + output_shape = split_input.shape.resize_dimension( + self._split_dim.name, + self._split_dim.size*2, + ) + self._outputs = [Tensor(self, output_shape, tf.float32)] + + def gradient(self, grad_ys): + dy = grad_ys[0] + dy_complex = to_complex(dy) + return [dy_complex] + + def lower(self, lowering): + mesh_impl = lowering.mesh_impl(self) + split_input = self.inputs[0] + if mesh_impl.tensor_dimension_to_mesh_axis(self._split_dim) is not None: + raise ValueError("can't slice along complex split dimension") + def tf_fn(tf_input): + tf_real = tf.math.real(tf_input) + tf_imag = tf.math.imag(tf_input) + output = tf.concat([tf_real, tf_imag], axis=self._split_axis) + return output + y = mesh_impl.slicewise(tf_fn, lowering.tensors[split_input]) + lowering.set_tensor_lowering(self.outputs[0], y) From 87510bccaa18b6b966d0752092d2a7350830702a Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 16:11:24 +0100 Subject: [PATCH 03/21] added complex manipulation ops tests --- mesh_tensorflow/ops_test.py | 45 +++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/mesh_tensorflow/ops_test.py b/mesh_tensorflow/ops_test.py index 168b588d..d04d258b 100644 --- a/mesh_tensorflow/ops_test.py +++ b/mesh_tensorflow/ops_test.py @@ -622,6 +622,51 @@ def x_squared_plus_x(x): self.evaluate(expected_dx)) +class ComplexManipulationTest(tf.test.TestCase): + def setUP(self): + super(ComplexManipulationTest, self).setUp() + self.graph = mtf.Graph() + self.mesh = mtf.Mesh(self.graph, "my_mesh") + + def test_to_complex(self): + tensor = tf.random.normal([1, 10, 4]) + mtf_shape = [mtf.Dimension(f'dim_{i}', s) for i, s in enumerate(tensor.shape)] + tensor_mesh = mtf.import_tf_tensor(self.mesh, tensor, shape=mtf_shape) + outputs = mtf.to_complex(tensor_mesh) + assert outputs.dtype == tf.complex64 + assert len(outputs.shape) == 3 + assert outputs.shape[-1].size == 2 + assert [s.size for s in outputs.shape[:-1]] == [s.size for s in tensor_mesh.shape[:-1]] + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( + shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) + outputs_tf = lowering.export_to_tf_tensor(outputs) + self.assertAllEqual( + outputs_tf, + tf.complex(tensor[..., 0:2], tensor[..., 2:4]), + ) + + def test_split_complex(self): + tensor = tf.complex( + tf.random.normal([1, 10, 2]), + tf.random.normal([1, 10, 2]), + ) + mtf_shape = [mtf.Dimension(f'dim_{i}', s) for i, s in enumerate(tensor.shape)] + tensor_mesh = mtf.import_tf_tensor(self.mesh, tensor, shape=mtf_shape) + outputs = mtf.split_complex(tensor_mesh) + assert outputs.dtype == tf.float32 + assert len(outputs.shape) == 3 + assert outputs.shape[-1].size == 4 + assert [s.size for s in outputs.shape[:-1]] == [s.size for s in tensor_mesh.shape[:-1]] + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( + shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) + outputs_tf = lowering.export_to_tf_tensor(outputs) + self.assertAllEqual( + outputs_tf, + tf.concat([tf.math.real(tensor), tf.math.imag(tensor)], axis=-1), + ) + if __name__ == "__main__": tf.disable_v2_behavior() tf.enable_eager_execution() From 159027ae2073cc919b397f747adcac10856300b2 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 16:14:51 +0100 Subject: [PATCH 04/21] cleaned signal ops imports --- mesh_tensorflow/signal_ops.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index 011f0969..ebb6cc3b 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -3,22 +3,13 @@ from __future__ import division from __future__ import print_function -import collections -import functools -import itertools -import operator -import os -import re - -from mesh_tensorflow import utils import numpy as np -import six -from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow.compat.v1 as tf from mesh_tensorflow import ops_with_redefined_builtins as mtf + class FFT3DOperation(mtf.Operation): """ Computes the 3-dimensional discrete Fourier transform over the inner-most 3 From 6f85ab6bdad2d149c1cf25ef4092be0f345e54b8 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 16:30:21 +0100 Subject: [PATCH 05/21] corrected complex ops test naming --- mesh_tensorflow/ops_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mesh_tensorflow/ops_test.py b/mesh_tensorflow/ops_test.py index d04d258b..cc00290e 100644 --- a/mesh_tensorflow/ops_test.py +++ b/mesh_tensorflow/ops_test.py @@ -623,12 +623,12 @@ def x_squared_plus_x(x): class ComplexManipulationTest(tf.test.TestCase): - def setUP(self): + def setUp(self): super(ComplexManipulationTest, self).setUp() self.graph = mtf.Graph() self.mesh = mtf.Mesh(self.graph, "my_mesh") - def test_to_complex(self): + def testToComplex(self): tensor = tf.random.normal([1, 10, 4]) mtf_shape = [mtf.Dimension(f'dim_{i}', s) for i, s in enumerate(tensor.shape)] tensor_mesh = mtf.import_tf_tensor(self.mesh, tensor, shape=mtf_shape) @@ -646,7 +646,7 @@ def test_to_complex(self): tf.complex(tensor[..., 0:2], tensor[..., 2:4]), ) - def test_split_complex(self): + def testSplitComplex(self): tensor = tf.complex( tf.random.normal([1, 10, 2]), tf.random.normal([1, 10, 2]), From 0e2760d068c4f41dedc98a66010c698d89fc732c Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 16:30:30 +0100 Subject: [PATCH 06/21] added tests for signal ops --- mesh_tensorflow/signal_ops_test.py | 53 ++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 mesh_tensorflow/signal_ops_test.py diff --git a/mesh_tensorflow/signal_ops_test.py b/mesh_tensorflow/signal_ops_test.py new file mode 100644 index 00000000..64c4620c --- /dev/null +++ b/mesh_tensorflow/signal_ops_test.py @@ -0,0 +1,53 @@ +import mesh_tensorflow as mtf +from mesh_tensorflow.signal_ops import fft3d, ifft3d +import tensorflow as tf + + +class FFTTest(tf.test.TestCase): + def setUp(self): + super(FFTTest, self).setUp() + self.graph = mtf.Graph() + self.mesh = mtf.Mesh(self.graph, "my_mesh") + volume_size = 32 + batch_dim = mtf.Dimension("batch", 1) + cols_dim = mtf.Dimension("cols", volume_size) + volume_channels_dim = mtf.Dimension('channels', 1) + slices_dim = mtf.Dimension("slices", volume_size) + rows_dim = mtf.Dimension("rows", volume_size) + self.shape = [batch_dim, slices_dim, rows_dim, cols_dim, volume_channels_dim] + volume_shape = [d.size for d in self.shape] + self.volume = tf.random.normal(volume_shape) + self.volume_mesh = mtf.import_tf_tensor(self.mesh, self.volume, shape=self.shape) + + + def testFft3d(self): + outputs = fft3d(self.volume_mesh, freq_dims=self.shape[1:4]) + assert len(outputs.shape) == 4 + assert outputs.dtype == tf.complex64 + assert outputs.shape == self.shape + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( + shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) + outputs_tf = lowering.export_to_tf_tensor(outputs) + tf_tester = tf.test.TestCase() + expected_outputs = tf.signal.fft3d(self.volume) + tf_tester.assertAllEqual( + outputs_tf, + expected_outputs, + ) + + def testIfft3d(self): + outputs = ifft3d(self.volume_mesh, dims=self.shape[1:4]) + assert len(outputs.shape) == 4 + assert outputs.dtype == tf.complex64 + assert outputs.shape == self.shape + mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( + shape=[], layout={}, devices=[""]) + lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) + outputs_tf = lowering.export_to_tf_tensor(outputs) + tf_tester = tf.test.TestCase() + expected_outputs = tf.signal.ifft3d(self.volume) + tf_tester.assertAllEqual( + outputs_tf, + expected_outputs, + ) From 14dda78e43fe3c75f9557bbca6d8cb360502b2b3 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 16:33:16 +0100 Subject: [PATCH 07/21] corrected volume shape in signal ops tests --- mesh_tensorflow/signal_ops_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mesh_tensorflow/signal_ops_test.py b/mesh_tensorflow/signal_ops_test.py index 64c4620c..a3d4a65b 100644 --- a/mesh_tensorflow/signal_ops_test.py +++ b/mesh_tensorflow/signal_ops_test.py @@ -10,11 +10,10 @@ def setUp(self): self.mesh = mtf.Mesh(self.graph, "my_mesh") volume_size = 32 batch_dim = mtf.Dimension("batch", 1) - cols_dim = mtf.Dimension("cols", volume_size) - volume_channels_dim = mtf.Dimension('channels', 1) slices_dim = mtf.Dimension("slices", volume_size) rows_dim = mtf.Dimension("rows", volume_size) - self.shape = [batch_dim, slices_dim, rows_dim, cols_dim, volume_channels_dim] + cols_dim = mtf.Dimension("cols", volume_size) + self.shape = [batch_dim, slices_dim, rows_dim, cols_dim,] volume_shape = [d.size for d in self.shape] self.volume = tf.random.normal(volume_shape) self.volume_mesh = mtf.import_tf_tensor(self.mesh, self.volume, shape=self.shape) From 69b2d76a7c2eb4f6493f04b95296874bafe63246 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 17:03:35 +0100 Subject: [PATCH 08/21] corrected shaping for signal ops tests --- mesh_tensorflow/signal_ops_test.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/mesh_tensorflow/signal_ops_test.py b/mesh_tensorflow/signal_ops_test.py index a3d4a65b..8350dcf1 100644 --- a/mesh_tensorflow/signal_ops_test.py +++ b/mesh_tensorflow/signal_ops_test.py @@ -15,7 +15,10 @@ def setUp(self): cols_dim = mtf.Dimension("cols", volume_size) self.shape = [batch_dim, slices_dim, rows_dim, cols_dim,] volume_shape = [d.size for d in self.shape] - self.volume = tf.random.normal(volume_shape) + self.volume = tf.complex( + tf.random.normal(volume_shape), + tf.random.normal(volume_shape), + ) self.volume_mesh = mtf.import_tf_tensor(self.mesh, self.volume, shape=self.shape) @@ -23,30 +26,40 @@ def testFft3d(self): outputs = fft3d(self.volume_mesh, freq_dims=self.shape[1:4]) assert len(outputs.shape) == 4 assert outputs.dtype == tf.complex64 - assert outputs.shape == self.shape + # assert outputs.shape == mtf.Shape(self.shape) + assert [d.size for d in outputs.shape] == [d.size for d in self.shape] mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) outputs_tf = lowering.export_to_tf_tensor(outputs) - tf_tester = tf.test.TestCase() expected_outputs = tf.signal.fft3d(self.volume) - tf_tester.assertAllEqual( + expected_outputs = tf.transpose(expected_outputs, perm=[0, 2, 3, 1]) + self.assertAllClose( outputs_tf, expected_outputs, + rtol=1e-4, + atol=1e-4, ) def testIfft3d(self): - outputs = ifft3d(self.volume_mesh, dims=self.shape[1:4]) + outputs = ifft3d( + self.volume_mesh, + # ordering is not the same for ifft3d + dims=[self.shape[3], self.shape[1], self.shape[2]], + ) assert len(outputs.shape) == 4 assert outputs.dtype == tf.complex64 - assert outputs.shape == self.shape + # assert outputs.shape == mtf.Shape(self.shape) + assert [d.size for d in outputs.shape] == [d.size for d in self.shape] mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) outputs_tf = lowering.export_to_tf_tensor(outputs) - tf_tester = tf.test.TestCase() expected_outputs = tf.signal.ifft3d(self.volume) - tf_tester.assertAllEqual( + expected_outputs = tf.transpose(expected_outputs, perm=[0, 3, 1, 2]) + self.assertAllClose( outputs_tf, expected_outputs, + rtol=1e-4, + atol=1e-4, ) From 71517569dc595de03fda82a9d17b1791f0bb97ce Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 17:09:47 +0100 Subject: [PATCH 09/21] slightly complexified the signal ops test --- mesh_tensorflow/signal_ops_test.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mesh_tensorflow/signal_ops_test.py b/mesh_tensorflow/signal_ops_test.py index 8350dcf1..3edabe3c 100644 --- a/mesh_tensorflow/signal_ops_test.py +++ b/mesh_tensorflow/signal_ops_test.py @@ -10,7 +10,7 @@ def setUp(self): self.mesh = mtf.Mesh(self.graph, "my_mesh") volume_size = 32 batch_dim = mtf.Dimension("batch", 1) - slices_dim = mtf.Dimension("slices", volume_size) + slices_dim = mtf.Dimension("slices", volume_size//2) rows_dim = mtf.Dimension("rows", volume_size) cols_dim = mtf.Dimension("cols", volume_size) self.shape = [batch_dim, slices_dim, rows_dim, cols_dim,] @@ -26,8 +26,7 @@ def testFft3d(self): outputs = fft3d(self.volume_mesh, freq_dims=self.shape[1:4]) assert len(outputs.shape) == 4 assert outputs.dtype == tf.complex64 - # assert outputs.shape == mtf.Shape(self.shape) - assert [d.size for d in outputs.shape] == [d.size for d in self.shape] + assert set(outputs.shape) == set(mtf.Shape(self.shape)) mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) @@ -49,8 +48,7 @@ def testIfft3d(self): ) assert len(outputs.shape) == 4 assert outputs.dtype == tf.complex64 - # assert outputs.shape == mtf.Shape(self.shape) - assert [d.size for d in outputs.shape] == [d.size for d in self.shape] + assert set(outputs.shape) == set(mtf.Shape(self.shape)) mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl( shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) From 2074197c82dc342ce603fb1ee674defa0fca531a Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 17:47:38 +0100 Subject: [PATCH 10/21] refactored fft3d direct --- mesh_tensorflow/signal_ops.py | 113 ++++++++++++++++++++++------------ 1 file changed, 73 insertions(+), 40 deletions(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index ebb6cc3b..7febae37 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -10,7 +10,66 @@ from mesh_tensorflow import ops_with_redefined_builtins as mtf -class FFT3DOperation(mtf.Operation): +class FFTBaseOperation(mtf.Operation): + def __init__(self, inputs, dims, inverse=False, name=None): + self.inverse = inverse + if self.inverse: + self.default_name = 'IFFT3D' + self.tf_op = tf.spectral.ifft + else: + self.default_name = 'FFT3D' + self.tf_op = tf.spectral.fft + super(FFTBaseOperation, self).__init__([inputs], name=name or self.default_name) + self._dims = dims + if self.inverse: + dims_reordered = dims + else: + dims_reordered = [dims[1], dims[2], dims[0]] + self._output_shape = mtf.Shape(inputs.shape[:-3]+dims_reordered) + self._outputs = [mtf.Tensor(self, mtf.Shape(self._output_shape), inputs.dtype)] + + def gradient(self, grad_ys): + dy = grad_ys[0] + if self.inverse: + ky, kz, kx = self.inputs[0].shape[-3:] + return [fft3d(dy, [kx, ky, kz])] + else: + x = self.inputs[0] + return [ifft3d(dy, x.shape[-3:])] + + def lower(self, lowering): + mesh_impl = lowering.mesh_impl(self) + x = self.inputs[0] + naxes = len(x.shape) + slices = lowering.tensors[self.inputs[0]] + # Before performing any operations, we check the splitting + split_axes = [] + for i in range(3): + split_axes.append(mesh_impl.tensor_dimension_to_mesh_axis(x.shape.dims[-3:][i])) + + # Perform transform followed by tranposes + for i in range(2): + # Apply FFT along last axis + slices = mesh_impl.slicewise(self.tf_op, slices) + + # Before transposing the array, making sure the new last dimension will + # be contiguous + split_axes, slices = self._make_sure_contiguous( + mesh_impl, + split_axes, + slices, + naxes, + ) + + # Apply transform along last axis + slices = mesh_impl.slicewise(self.tf_op, slices) + lowering.set_tensor_lowering(self.outputs[0], slices) + + def _make_sure_contiguous(self, *args): + raise NotImplementedError('This function needs to be implemented') + + +class FFT3DOperation(FFTBaseOperation): """ Computes the 3-dimensional discrete Fourier transform over the inner-most 3 dimensions of input tensor. Note that the output FFT is transposed. @@ -23,46 +82,20 @@ class FFT3DOperation(mtf.Operation): Returns: A Tensor of shape `input.shape[:-3] + freq_dims`. """ - def __init__(self, input, freq_dims, name=None): - super(FFT3DOperation, self).__init__([input], name=name or "FFT3D") - self._freq_dims = freq_dims - self._output_shape = mtf.Shape(input.shape[:-3]+[freq_dims[1], freq_dims[2], freq_dims[0]]) - self._outputs = [mtf.Tensor(self, mtf.Shape(self._output_shape), input.dtype)] - - def gradient(self, grad_ys): - dy = grad_ys[0] - x = self.inputs[0] - return [ifft3d(dy, x.shape[-3:])] - - def lower(self, lowering): - mesh_impl = lowering.mesh_impl(self) - x = self.inputs[0] - naxes = len(x.shape) - slices = lowering.tensors[self.inputs[0]] - # Before performing any operations, we check the splitting - split_axes = [] - for i in range(3): - split_axes.append(mesh_impl.tensor_dimension_to_mesh_axis(x.shape.dims[-3:][i])) - - # Perform FFT followed by tranposes - for i in range(2): - # Apply FFT along last axis - slices = mesh_impl.slicewise(tf.spectral.fft, slices) - - # Before transposing the array, making sure the new last dimension will - # be contiguous - if split_axes[-2] is not None: - slices = mesh_impl.alltoall(slices, split_axes[-2], naxes-1, naxes-2) - split_axes[-1] = split_axes[-2] - split_axes[-2] = None - perm = np.arange(len(x.shape)) - perm[-3:] = np.roll(perm[-3:], shift=1) - slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) - split_axes = [split_axes[2], split_axes[0], split_axes[1]] + def __init__(self, inputs, dims, name=None): + super(FFT3DOperation, self).__init__(inputs, dims, inverse=False, name=name) + + def _make_sure_contiguous(self, mesh_impl, split_axes, slices, naxes): + if split_axes[-2] is not None: + slices = mesh_impl.alltoall(slices, split_axes[-2], naxes-1, naxes-2) + split_axes[-1] = split_axes[-2] + split_axes[-2] = None + perm = np.arange(naxes) + perm[-3:] = np.roll(perm[-3:], shift=1) + slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) + split_axes = [split_axes[2], split_axes[0], split_axes[1]] + return split_axes, slices - # Apply FFT along last axis - slices = mesh_impl.slicewise(tf.spectral.fft, slices) - lowering.set_tensor_lowering(self.outputs[0], slices) def fft3d(x, freq_dims, name=None): """ From afe7c63d66d426ca4a6c166b4deb4d9ee55aeaa0 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 17:51:47 +0100 Subject: [PATCH 11/21] refactored ifft3d --- mesh_tensorflow/signal_ops.py | 56 ++++++++++------------------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index 7febae37..05891eb2 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -25,8 +25,8 @@ def __init__(self, inputs, dims, inverse=False, name=None): dims_reordered = dims else: dims_reordered = [dims[1], dims[2], dims[0]] - self._output_shape = mtf.Shape(inputs.shape[:-3]+dims_reordered) - self._outputs = [mtf.Tensor(self, mtf.Shape(self._output_shape), inputs.dtype)] + self._output_shape = mtf.Shape(inputs.shape[:-3]+dims_reordered) + self._outputs = [mtf.Tensor(self, mtf.Shape(self._output_shape), inputs.dtype)] def gradient(self, grad_ys): dy = grad_ys[0] @@ -112,7 +112,7 @@ def fft3d(x, freq_dims, name=None): """ return FFT3DOperation(x, freq_dims, name).outputs[0] -class iFFT3DOperation(mtf.Operation): +class iFFT3DOperation(FFTBaseOperation): """ Computes the inverse 3-dimensional discrete Fourier transform over the inner-most 3 dimensions of input tensor. Note that the input FFT is assumed transposed. @@ -125,46 +125,20 @@ class iFFT3DOperation(mtf.Operation): Returns: A Tensor of shape `input.shape[:-3] + dims`. """ - def __init__(self, input, dims, name=None): - super(iFFT3DOperation, self).__init__([input], name=name or "iFFT3D") - self._dims = dims - self._output_shape = mtf.Shape(input.shape[:-3]+dims) - self._outputs = [mtf.Tensor(self, mtf.Shape(self._output_shape), input.dtype)] + def __init__(self, inputs, dims, name=None): + super(iFFT3DOperation, self).__init__(inputs, dims, inverse=True, name=name) - def gradient(self, grad_ys): - dy = grad_ys[0] - ky, kz, kx = self.inputs[0].shape[-3:] - return [fft3d(dy, [kx, ky, kz])] + def _make_sure_contiguous(self, mesh_impl, split_axes, slices, naxes): + if split_axes[0] is not None: + slices = mesh_impl.alltoall(slices, split_axes[0], naxes-1, naxes-3) + split_axes[-1] = split_axes[0] + split_axes[0] = None + perm = np.arange(naxes) + perm[-3:] = np.roll(perm[-3:], shift=-1) + slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) + split_axes = [split_axes[1], split_axes[2], split_axes[0]] + return split_axes, slices - def lower(self, lowering): - mesh_impl = lowering.mesh_impl(self) - x = self.inputs[0] - naxes = len(x.shape) - slices = lowering.tensors[self.inputs[0]] - # Before performing any operations, we check the splitting - split_axes = [] - for i in range(3): - split_axes.append(mesh_impl.tensor_dimension_to_mesh_axis(x.shape.dims[-3:][i])) - - # Perform FFT followed by tranposes - for i in range(2): - # Apply FFT along last axis - slices = mesh_impl.slicewise(tf.spectral.ifft, slices) - - # Before transposing the array, making sure the new last dimension will - # be contiguous - if split_axes[0] is not None: - slices = mesh_impl.alltoall(slices, split_axes[0], naxes-1, naxes-3) - split_axes[-1] = split_axes[0] - split_axes[0] = None - perm = np.arange(len(x.shape)) - perm[-3:] = np.roll(perm[-3:], shift=-1) - slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) - split_axes = [split_axes[1], split_axes[2], split_axes[0]] - - # Apply FFT along last axis - slices = mesh_impl.slicewise(tf.spectral.ifft, slices) - lowering.set_tensor_lowering(self.outputs[0], slices) def ifft3d(x, dims, name=None): """ From 9ef47fc72bd276cb5e1af9c419d60d3ab1a6ae53 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 17:55:55 +0100 Subject: [PATCH 12/21] corrected documentation for shape return of fft3d --- mesh_tensorflow/signal_ops.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index 05891eb2..e84bd210 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -45,21 +45,21 @@ def lower(self, lowering): # Before performing any operations, we check the splitting split_axes = [] for i in range(3): - split_axes.append(mesh_impl.tensor_dimension_to_mesh_axis(x.shape.dims[-3:][i])) + split_axes.append(mesh_impl.tensor_dimension_to_mesh_axis(x.shape.dims[-3:][i])) # Perform transform followed by tranposes for i in range(2): - # Apply FFT along last axis - slices = mesh_impl.slicewise(self.tf_op, slices) - - # Before transposing the array, making sure the new last dimension will - # be contiguous - split_axes, slices = self._make_sure_contiguous( - mesh_impl, - split_axes, - slices, - naxes, - ) + # Apply FFT along last axis + slices = mesh_impl.slicewise(self.tf_op, slices) + + # Before transposing the array, making sure the new last dimension will + # be contiguous + split_axes, slices = self._make_sure_contiguous( + mesh_impl, + split_axes, + slices, + naxes, + ) # Apply transform along last axis slices = mesh_impl.slicewise(self.tf_op, slices) @@ -80,7 +80,7 @@ class FFT3DOperation(FFTBaseOperation): name: A name for the operation (optional). Returns: - A Tensor of shape `input.shape[:-3] + freq_dims`. + A Tensor of shape `input.shape[:-3] + freq_dims[1] + freq_dims[2] + freq_dims[0]`. """ def __init__(self, inputs, dims, name=None): super(FFT3DOperation, self).__init__(inputs, dims, inverse=False, name=name) From 782250e7f15936e8b8b374c793de191de151e80b Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Wed, 25 Nov 2020 17:58:24 +0100 Subject: [PATCH 13/21] changed transpose name --- mesh_tensorflow/signal_ops.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index e84bd210..8a950791 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -52,9 +52,7 @@ def lower(self, lowering): # Apply FFT along last axis slices = mesh_impl.slicewise(self.tf_op, slices) - # Before transposing the array, making sure the new last dimension will - # be contiguous - split_axes, slices = self._make_sure_contiguous( + split_axes, slices = self._transpose( mesh_impl, split_axes, slices, @@ -65,7 +63,7 @@ def lower(self, lowering): slices = mesh_impl.slicewise(self.tf_op, slices) lowering.set_tensor_lowering(self.outputs[0], slices) - def _make_sure_contiguous(self, *args): + def _transpose(self, *args): raise NotImplementedError('This function needs to be implemented') @@ -85,7 +83,9 @@ class FFT3DOperation(FFTBaseOperation): def __init__(self, inputs, dims, name=None): super(FFT3DOperation, self).__init__(inputs, dims, inverse=False, name=name) - def _make_sure_contiguous(self, mesh_impl, split_axes, slices, naxes): + def _transpose(self, mesh_impl, split_axes, slices, naxes): + # Before transposing the array, making sure the new last dimension will + # be contiguous if split_axes[-2] is not None: slices = mesh_impl.alltoall(slices, split_axes[-2], naxes-1, naxes-2) split_axes[-1] = split_axes[-2] @@ -128,7 +128,9 @@ class iFFT3DOperation(FFTBaseOperation): def __init__(self, inputs, dims, name=None): super(iFFT3DOperation, self).__init__(inputs, dims, inverse=True, name=name) - def _make_sure_contiguous(self, mesh_impl, split_axes, slices, naxes): + def _transpose(self, mesh_impl, split_axes, slices, naxes): + # Before transposing the array, making sure the new last dimension will + # be contiguous if split_axes[0] is not None: slices = mesh_impl.alltoall(slices, split_axes[0], naxes-1, naxes-3) split_axes[-1] = split_axes[0] From 7ea6a9036f776f4f0c1c3f6a1a0918657fca2e46 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Thu, 26 Nov 2020 19:32:31 +0100 Subject: [PATCH 14/21] simplified the implementation of the split complex op --- mesh_tensorflow/ops.py | 60 ++++++++++++++---------------------------- 1 file changed, 20 insertions(+), 40 deletions(-) diff --git a/mesh_tensorflow/ops.py b/mesh_tensorflow/ops.py index 1df39936..1bd6671a 100644 --- a/mesh_tensorflow/ops.py +++ b/mesh_tensorflow/ops.py @@ -6630,44 +6630,24 @@ def split_complex(x, complex_dim=None): Returns: a Tensor, float-valued """ - op = SplitComplexOperation(x, complex_dim=complex_dim) - output = op.outputs[0] + if complex_dim is None: + split_dim = x.shape.dims[-1] + split_axis = -1 + else: + split_dim = complex_dim + split_axis = x.shape.index(complex_dim) + splittable_dims = [d for d in x.shape if d != split_dim] + def tf_fn(tf_input): + tf_real = tf.math.real(tf_input) + tf_imag = tf.math.imag(tf_input) + output = tf.concat([tf_real, tf_imag], axis=split_axis) + return output + output = slicewise( + tf_fn, + x, + output_shape=x.shape.resize_dimension(split_dim.name, split_dim.size*2), + output_dtype=tf.float32, + splittable_dims=splittable_dims, + name='split_complex', + ) return output - -class SplitComplexOperation(Operation): - def __init__(self, split_input, complex_dim=None, name=None): - super().__init__([split_input], name=name or 'split_complex') - if complex_dim is None: - self._split_dim = split_input.shape.dims[-1] - self._split_axis = -1 - else: - self._split_dim = complex_dim - self._split_axis = split_input.shape.index(complex_dim) - self._splittable_dims, self._unsplittable_dims = ( - self._initialize_splittable_and_unsplittable_dims( - "splittable", [self._split_dim.name], - ) - ) - output_shape = split_input.shape.resize_dimension( - self._split_dim.name, - self._split_dim.size*2, - ) - self._outputs = [Tensor(self, output_shape, tf.float32)] - - def gradient(self, grad_ys): - dy = grad_ys[0] - dy_complex = to_complex(dy) - return [dy_complex] - - def lower(self, lowering): - mesh_impl = lowering.mesh_impl(self) - split_input = self.inputs[0] - if mesh_impl.tensor_dimension_to_mesh_axis(self._split_dim) is not None: - raise ValueError("can't slice along complex split dimension") - def tf_fn(tf_input): - tf_real = tf.math.real(tf_input) - tf_imag = tf.math.imag(tf_input) - output = tf.concat([tf_real, tf_imag], axis=self._split_axis) - return output - y = mesh_impl.slicewise(tf_fn, lowering.tensors[split_input]) - lowering.set_tensor_lowering(self.outputs[0], y) From bcd791e24c485f527a19e3f2d17d21d46872a137 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Thu, 26 Nov 2020 19:34:22 +0100 Subject: [PATCH 15/21] corrected slicewise input list expected --- mesh_tensorflow/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesh_tensorflow/ops.py b/mesh_tensorflow/ops.py index 1bd6671a..500bfeb9 100644 --- a/mesh_tensorflow/ops.py +++ b/mesh_tensorflow/ops.py @@ -6644,7 +6644,7 @@ def tf_fn(tf_input): return output output = slicewise( tf_fn, - x, + [x], output_shape=x.shape.resize_dimension(split_dim.name, split_dim.size*2), output_dtype=tf.float32, splittable_dims=splittable_dims, From ea9b55f607b971381430a14e1c8f2d400c699d24 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Thu, 26 Nov 2020 19:35:49 +0100 Subject: [PATCH 16/21] corrected name for the base op --- mesh_tensorflow/signal_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index 8a950791..2784719f 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -10,7 +10,7 @@ from mesh_tensorflow import ops_with_redefined_builtins as mtf -class FFTBaseOperation(mtf.Operation): +class FFT3DBaseOperation(mtf.Operation): def __init__(self, inputs, dims, inverse=False, name=None): self.inverse = inverse if self.inverse: @@ -67,7 +67,7 @@ def _transpose(self, *args): raise NotImplementedError('This function needs to be implemented') -class FFT3DOperation(FFTBaseOperation): +class FFT3DOperation(FFT3DBaseOperation): """ Computes the 3-dimensional discrete Fourier transform over the inner-most 3 dimensions of input tensor. Note that the output FFT is transposed. @@ -112,7 +112,7 @@ def fft3d(x, freq_dims, name=None): """ return FFT3DOperation(x, freq_dims, name).outputs[0] -class iFFT3DOperation(FFTBaseOperation): +class iFFT3DOperation(FFT3DBaseOperation): """ Computes the inverse 3-dimensional discrete Fourier transform over the inner-most 3 dimensions of input tensor. Note that the input FFT is assumed transposed. From ce9e850400b5e070dab7a95482ade8ca6072a242 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Thu, 26 Nov 2020 19:36:15 +0100 Subject: [PATCH 17/21] corrected doc for fft3d function --- mesh_tensorflow/signal_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index 2784719f..ba4894c1 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -108,7 +108,7 @@ def fft3d(x, freq_dims, name=None): name: A name for the operation (optional). Returns: - A Tensor of shape `input.shape[:-3] + freq_dims`. + A Tensor of shape `input.shape[:-3] + freq_dims[1] + freq_dims[2] + freq_dims[0]`. """ return FFT3DOperation(x, freq_dims, name).outputs[0] From f6a2dfe6f9d79084e651f203b01fca72327479e6 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Thu, 26 Nov 2020 19:38:24 +0100 Subject: [PATCH 18/21] corrected legacy typo in fft base class --- mesh_tensorflow/signal_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index ba4894c1..2a94ae86 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -19,7 +19,7 @@ def __init__(self, inputs, dims, inverse=False, name=None): else: self.default_name = 'FFT3D' self.tf_op = tf.spectral.fft - super(FFTBaseOperation, self).__init__([inputs], name=name or self.default_name) + super(FFT3DBaseOperation, self).__init__([inputs], name=name or self.default_name) self._dims = dims if self.inverse: dims_reordered = dims From 497c2575d862acbb4dd44f56bdc129ee430baf91 Mon Sep 17 00:00:00 2001 From: Zaccharie Ramzi Date: Thu, 26 Nov 2020 19:38:43 +0100 Subject: [PATCH 19/21] corrected order of transpose in fft test to compare the output directly --- mesh_tensorflow/signal_ops_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mesh_tensorflow/signal_ops_test.py b/mesh_tensorflow/signal_ops_test.py index 3edabe3c..a6f75f10 100644 --- a/mesh_tensorflow/signal_ops_test.py +++ b/mesh_tensorflow/signal_ops_test.py @@ -53,8 +53,8 @@ def testIfft3d(self): shape=[], layout={}, devices=[""]) lowering = mtf.Lowering(self.graph, {self.mesh: mesh_impl}) outputs_tf = lowering.export_to_tf_tensor(outputs) - expected_outputs = tf.signal.ifft3d(self.volume) - expected_outputs = tf.transpose(expected_outputs, perm=[0, 3, 1, 2]) + volume = tf.transpose(self.volume, perm=[0, 3, 1, 2]) + expected_outputs = tf.signal.ifft3d(volume) self.assertAllClose( outputs_tf, expected_outputs, From 28debd58c08b6b31e21dfdd8a7236dddad61bdd1 Mon Sep 17 00:00:00 2001 From: EiffL Date: Thu, 26 Nov 2020 20:29:18 +0100 Subject: [PATCH 20/21] Proposes some further code compression to remove duplication of transpose ops --- mesh_tensorflow/signal_ops.py | 88 +++++++++++------------------------ 1 file changed, 26 insertions(+), 62 deletions(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index 2a94ae86..9d8d58a2 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -11,6 +11,13 @@ class FFT3DBaseOperation(mtf.Operation): + """ Base class for performing distributed FFTs. + + Handles slicewise ffts and array transpositions. Note that to save one global + transposition at the end of forward and inverse FFTs, these operations + assume a transposed fourier space with shape: + input.shape[:-3] + freq_dims[1] + freq_dims[2] + freq_dims[0] + """ def __init__(self, inputs, dims, inverse=False, name=None): self.inverse = inverse if self.inverse: @@ -63,39 +70,26 @@ def lower(self, lowering): slices = mesh_impl.slicewise(self.tf_op, slices) lowering.set_tensor_lowering(self.outputs[0], slices) - def _transpose(self, *args): - raise NotImplementedError('This function needs to be implemented') - - -class FFT3DOperation(FFT3DBaseOperation): - """ - Computes the 3-dimensional discrete Fourier transform over the inner-most 3 - dimensions of input tensor. Note that the output FFT is transposed. - - Args: - input: A Tensor. Must be one of the following types: complex64, complex128 - freq_dims: List of 3 Dimensions representing the frequency dimensions. - name: A name for the operation (optional). - - Returns: - A Tensor of shape `input.shape[:-3] + freq_dims[1] + freq_dims[2] + freq_dims[0]`. - """ - def __init__(self, inputs, dims, name=None): - super(FFT3DOperation, self).__init__(inputs, dims, inverse=False, name=name) - def _transpose(self, mesh_impl, split_axes, slices, naxes): - # Before transposing the array, making sure the new last dimension will - # be contiguous + # Before transposing the array, making sure the new last dimension will + # be contiguous + if self.inverse: + if split_axes[0] is not None: + slices = mesh_impl.alltoall(slices, split_axes[0], naxes-1, naxes-3) + split_axes[-1] = split_axes[0] + split_axes[0] = None + split_axes = [split_axes[1], split_axes[2], split_axes[0]] + else: if split_axes[-2] is not None: - slices = mesh_impl.alltoall(slices, split_axes[-2], naxes-1, naxes-2) - split_axes[-1] = split_axes[-2] - split_axes[-2] = None - perm = np.arange(naxes) - perm[-3:] = np.roll(perm[-3:], shift=1) - slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) - split_axes = [split_axes[2], split_axes[0], split_axes[1]] - return split_axes, slices + slices = mesh_impl.alltoall(slices, split_axes[-2], naxes-1, naxes-2) + split_axes[-1] = split_axes[-2] + split_axes[-2] = None + split_axes = [split_axes[2], split_axes[0], split_axes[1]] + perm = np.arange(naxes) + perm[-3:] = np.roll(perm[-3:], shift=-1) + slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) + return split_axes, slices def fft3d(x, freq_dims, name=None): """ @@ -110,37 +104,7 @@ def fft3d(x, freq_dims, name=None): Returns: A Tensor of shape `input.shape[:-3] + freq_dims[1] + freq_dims[2] + freq_dims[0]`. """ - return FFT3DOperation(x, freq_dims, name).outputs[0] - -class iFFT3DOperation(FFT3DBaseOperation): - """ - Computes the inverse 3-dimensional discrete Fourier transform over the inner-most 3 - dimensions of input tensor. Note that the input FFT is assumed transposed. - - Args: - input: A Tensor. Must be one of the following types: complex64, complex128 - dims: List of 3 Dimensions representing the direct space dimensions. - name: A name for the operation (optional). - - Returns: - A Tensor of shape `input.shape[:-3] + dims`. - """ - def __init__(self, inputs, dims, name=None): - super(iFFT3DOperation, self).__init__(inputs, dims, inverse=True, name=name) - - def _transpose(self, mesh_impl, split_axes, slices, naxes): - # Before transposing the array, making sure the new last dimension will - # be contiguous - if split_axes[0] is not None: - slices = mesh_impl.alltoall(slices, split_axes[0], naxes-1, naxes-3) - split_axes[-1] = split_axes[0] - split_axes[0] = None - perm = np.arange(naxes) - perm[-3:] = np.roll(perm[-3:], shift=-1) - slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) - split_axes = [split_axes[1], split_axes[2], split_axes[0]] - return split_axes, slices - + return FFT3DBaseOperation(x, freq_dims, inverse=False, name=name).outputs[0] def ifft3d(x, dims, name=None): """ @@ -155,4 +119,4 @@ def ifft3d(x, dims, name=None): Returns: A Tensor of shape `input.shape[:-3] + dims`. """ - return iFFT3DOperation(x, dims, name).outputs[0] + return FFT3DBaseOperation(x, freq_dims, inverse=True, name=name).outputs[0] From daab646b24f0f28ebb58fe995ea6077e6cfc34df Mon Sep 17 00:00:00 2001 From: EiffL Date: Thu, 26 Nov 2020 21:00:13 +0100 Subject: [PATCH 21/21] fix bug introduced by previous commit --- mesh_tensorflow/signal_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mesh_tensorflow/signal_ops.py b/mesh_tensorflow/signal_ops.py index 9d8d58a2..27fdc5f3 100644 --- a/mesh_tensorflow/signal_ops.py +++ b/mesh_tensorflow/signal_ops.py @@ -87,7 +87,7 @@ def _transpose(self, mesh_impl, split_axes, slices, naxes): split_axes = [split_axes[2], split_axes[0], split_axes[1]] perm = np.arange(naxes) - perm[-3:] = np.roll(perm[-3:], shift=-1) + perm[-3:] = np.roll(perm[-3:], shift=-1 if self.inverse else 1) slices = mesh_impl.slicewise(lambda x: tf.transpose(x, perm), slices) return split_axes, slices @@ -119,4 +119,4 @@ def ifft3d(x, dims, name=None): Returns: A Tensor of shape `input.shape[:-3] + dims`. """ - return FFT3DBaseOperation(x, freq_dims, inverse=True, name=name).outputs[0] + return FFT3DBaseOperation(x, dims, inverse=True, name=name).outputs[0]