From f1502f04740ed2c544284e039fd98a4784e562d7 Mon Sep 17 00:00:00 2001
From: Mesh TensorFlow Team
Date: Fri, 22 Jan 2021 12:10:36 -0800
Subject: [PATCH] internal

PiperOrigin-RevId: 353292528
---
 examples/toy_model_tpu.py | 292 --------------------------------------
 1 file changed, 292 deletions(-)
 delete mode 100644 examples/toy_model_tpu.py

diff --git a/examples/toy_model_tpu.py b/examples/toy_model_tpu.py
deleted file mode 100644
index 9d836941..00000000
--- a/examples/toy_model_tpu.py
+++ /dev/null
@@ -1,292 +0,0 @@
-# coding=utf-8
-# Copyright 2021 The Mesh TensorFlow Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""A toy model using Mesh TensorFlow."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import mesh_tensorflow as mtf
-import numpy
-import tensorflow.compat.v1 as tf
-
-from tensorflow.python.data.ops.dataset_ops import Dataset
-from tensorflow.python.platform import flags
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.tpu import tpu_config  # pylint: disable=g-direct-tensorflow-import
-from tensorflow.python.tpu import tpu_estimator  # pylint: disable=g-direct-tensorflow-import
-from tensorflow_estimator.python.estimator import estimator as estimator_lib
-
-FLAGS = flags.FLAGS
-
-tf.flags.DEFINE_integer('batch_size', 64, 'Training batch size.')
-tf.flags.DEFINE_integer('io_size', 16, 'Number of channels per feature.')
-tf.flags.DEFINE_integer('hidden_size', 16, 'Size of each hidden layer.')
-tf.flags.DEFINE_integer('num_hidden_layers', 1, 'Number of layers.')
-tf.flags.DEFINE_string('master_dtype', 'bfloat16', 'dtype for master vars.')
-tf.flags.DEFINE_string('slice_dtype', 'float32', 'dtype for slice vars.')
-tf.flags.DEFINE_string('activation_dtype', 'float32', 'dtype for activations.')
-tf.flags.DEFINE_string('optimizer', 'SGD', 'optimizer (SGD or Adafactor).')
-tf.flags.DEFINE_float('lr', 1e-4, 'Learning rate.')
-tf.flags.DEFINE_string('mesh_shape', 'all:8', 'mesh shape')
-tf.flags.DEFINE_string('layout', 'hidden_odd:all', 'layout rules')
-tf.flags.DEFINE_integer('iterations', 100,
-                        'Number of iterations per training loop.')
-tf.flags.DEFINE_integer('step_with_nan', -1,
-                        'If >= 0, a NaN tensor is added in forward pass.')
-tf.flags.DEFINE_integer('train_steps', 10000, 'max steps')
-tf.flags.DEFINE_integer('steps_per_checkpoint', 200, 'steps_per_checkpoint')
-tf.flags.DEFINE_string(
-    'model_dir',
-    default='',
-    help='The directory where the model will be stored.')
-tf.flags.DEFINE_bool('use_tpu', True, 'use TPU')
-
-# Cloud TPU Cluster Resolvers
-tf.flags.DEFINE_string(
-    'tpu',
-    default=None,
-    help='The Cloud TPU to use for training. This should be either the name '
-    'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.')
-
-tf.flags.DEFINE_string(
-    'gcp_project',
-    default=None,
-    help='Project name for the Cloud TPU-enabled project. If not specified, we '
-    'will attempt to automatically detect the GCE project from metadata.')
-
-tf.flags.DEFINE_string(
-    'tpu_zone',
-    default=None,
-    help='GCE zone where the Cloud TPU is located. If not specified, we '
-    'will attempt to automatically detect the zone from metadata.')
-
-
-class ToyModelInput(object):
-  """Wrapper class that acts as the input_fn to TPUEstimator."""
-
-  def __init__(self):
-    self._num_examples = 10000  # 10k
-    self._images = numpy.random.uniform(
-        0, 1.0, [self._num_examples, FLAGS.io_size]).astype(numpy.float32)
-    self._labels = self._images
-    logging.info('init ToyModelInput()')
-
-  def __call__(self, params):
-    """Input function which provides a single batch for train or eval."""
-    # Retrieves the batch size for the current shard. The number of shards is
-    # computed according to the input pipeline deployment. See
-    # `tf.estimator.tpu.RunConfig` for details.
-    batch_size = params['batch_size']
-    logging.info('call ToyModelInput() with batch size {}'.format(batch_size))
-
-    ds = Dataset.from_tensor_slices((self._images, self._labels)).repeat()
-
-    dataset = ds.batch(batch_size, drop_remainder=True).prefetch(2)
-
-    return dataset
-
-
-def toy_model(features, mesh):
-  """A toy model implemented with Mesh TensorFlow."""
-  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
-  io_dim = mtf.Dimension('io', FLAGS.io_size)
-
-  master_dtype = tf.as_dtype(FLAGS.master_dtype)
-  slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
-  activation_dtype = tf.as_dtype(FLAGS.activation_dtype)
-
-  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
-  x = mtf.cast(x, activation_dtype)
-  h = x
-  for lnum in range(1, FLAGS.num_hidden_layers + 2):
-    if lnum + 1 == FLAGS.num_hidden_layers + 2:
-      # output layer
-      dim = io_dim
-    elif lnum % 2 == 0:
-      dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
-    else:
-      dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
-    h = mtf.layers.dense(
-        h, dim,
-        use_bias=False,
-        master_dtype=master_dtype,
-        slice_dtype=slice_dtype,
-        name='layer_%d' % lnum)
-  y = h
-  g = tf.train.get_global_step()
-  if FLAGS.step_with_nan >= 0:
-    # Trigger a NaN in the forward pass. This is used to test whether
-    # Mesh TensorFlow can handle an occasional NaN value.
-    y += mtf.import_tf_tensor(
-        mesh,
-        tf.divide(
-            0.0,
-            tf.cond(tf.equal(g, FLAGS.step_with_nan), lambda: 0., lambda: 1.)),
-        mtf.Shape([]))
-
-  loss = mtf.reduce_mean(mtf.square(y - x))
-  return y, loss
-
-
-def model_fn(features, labels, mode, params):
-  """Model function called by TPUEstimator."""
-  del labels
-  global_step = tf.train.get_global_step()
-  graph = mtf.Graph()
-  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
-  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
-  if FLAGS.use_tpu:
-    ctx = params['context']
-    num_hosts = ctx.num_hosts
-    host_placement_fn = ctx.tpu_host_placement_function
-    device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
-    tf.logging.info('device_list = %s' % device_list)
-    # TODO(ylc): Better estimation of replica cache size?
-    replica_cache_size = 300 * 1000000  # 300M per replica
-    # Worker 0 caches all the TPU binaries.
-    worker0_mem = replica_cache_size * ctx.num_replicas
-    devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
-    var_placer = mtf.utils.BalancedVariablePlacer(device_list,
-                                                  devices_memory_usage)
-    mesh_devices = [''] * mesh_shape.size
-    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
-        mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
-  else:
-    var_placer = None
-    mesh_devices = [''] * mesh_shape.size
-    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
-        mesh_shape, layout_rules, mesh_devices)
-  mesh = mtf.Mesh(graph, 'my_mesh', var_placer)
-
-  with mtf.utils.outside_all_rewrites():
-    logits, loss = toy_model(features, mesh)
-
-  # TRAIN mode
-  if mode == tf.estimator.ModeKeys.TRAIN:
-    var_grads = mtf.gradients(
-        [loss], [v.outputs[0] for v in graph.trainable_variables])
-    if FLAGS.optimizer == 'Adafactor':
-      optimizer = mtf.optimize.AdafactorOptimizer()
-    else:
-      assert FLAGS.optimizer == 'SGD'
-      optimizer = mtf.optimize.SgdOptimizer(learning_rate=FLAGS.lr)
-    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
-  else:
-    # For now, we can only export fully-replicated tensors.
-    fully_replicated_logits = mtf.anonymize(logits)
-
-  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
-
-  tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
-
-  if mode == tf.estimator.ModeKeys.TRAIN:
-    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
-    tf_update_ops.append(tf.assign_add(global_step, 1))
-    tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
-    train_op = tf.group(tf_update_ops)
-  else:
-    tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)
-
-  with mtf.utils.outside_all_rewrites():
-    # Copy master variables to slices. Must be called first.
-    restore_hook = mtf.MtfRestoreHook(lowering)
-    if mode == tf.estimator.ModeKeys.TRAIN:
-      saver = tf.train.Saver(
-          tf.global_variables(),
-          sharded=True,
-          max_to_keep=10,
-          keep_checkpoint_every_n_hours=2,
-          defer_build=False,
-          save_relative_paths=True)
-      tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
-      saver_listener = mtf.MtfCheckpointSaverListener(lowering)
-      saver_hook = tf.train.CheckpointSaverHook(
-          FLAGS.model_dir,
-          save_steps=1000,
-          saver=saver,
-          listeners=[saver_listener])
-
-      return tpu_estimator.TPUEstimatorSpec(
-          tf.estimator.ModeKeys.TRAIN,
-          loss=tf_loss,
-          train_op=train_op,
-          training_hooks=[restore_hook, saver_hook])
-    elif mode == tf.estimator.ModeKeys.EVAL:
-
-      def metric_fn(tf_logits):
-        mean_logits = tf.metrics.mean(tf_logits)
-        return {'mean_logits': mean_logits}
-
-      eval_metrics = (metric_fn, [tf_logits])
-
-      return tpu_estimator.TPUEstimatorSpec(
-          tf.estimator.ModeKeys.EVAL,
-          evaluation_hooks=[restore_hook],
-          loss=tf_loss,
-          eval_metrics=eval_metrics)
-
-
-def run_toy_model_tpu():
-  """Run a toy model on TPU."""
-  tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
-
-  iterations_per_loop = FLAGS.iterations
-  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
-  config = tpu_config.RunConfig(
-      cluster=tpu_cluster_resolver,
-      model_dir=FLAGS.model_dir,
-      save_checkpoints_steps=None,  # Disable the default saver
-      save_checkpoints_secs=None,  # Disable the default saver
-      log_step_count_steps=iterations_per_loop,
-      save_summary_steps=iterations_per_loop,
-      tpu_config=tpu_config.TPUConfig(
-          num_shards=mesh_shape.size,
-          iterations_per_loop=iterations_per_loop,
-          num_cores_per_replica=1,
-          per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST))
-  classifier = tpu_estimator.TPUEstimator(
-      use_tpu=True,
-      model_fn=model_fn,
-      config=config,
-      train_batch_size=FLAGS.batch_size,
-      eval_batch_size=FLAGS.batch_size)
-  current_step = estimator_lib._load_global_step_from_checkpoint_dir(FLAGS.model_dir)  # pylint: disable=protected-access,line-too-long
-  logging.info('Current step %d', current_step)
-  if FLAGS.steps_per_checkpoint == 0:
-    classifier.train(input_fn=ToyModelInput(), max_steps=FLAGS.train_steps)
-    return
-  while current_step < FLAGS.train_steps:
-    next_checkpoint = min(current_step + FLAGS.steps_per_checkpoint,
-                          FLAGS.train_steps)
-    classifier.train(input_fn=ToyModelInput(), max_steps=next_checkpoint)
-    current_step = next_checkpoint
-    logging.info('Starting to evaluate.')
-    eval_results = classifier.evaluate(
-        input_fn=ToyModelInput(),
-        steps=156)  # since we have 10000 examples and batch_size = 64 per host
-    logging.info('Eval results: %s', eval_results)
-
-
-def main(_):
-  run_toy_model_tpu()
-
-
-if __name__ == '__main__':
-  tf.disable_v2_behavior()
-  tf.logging.set_verbosity(tf.logging.INFO)
-  tf.app.run()
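The deleted example is wired to TPUEstimator end to end, but the core Mesh
TensorFlow pattern it demonstrates (declare named dimensions, build an mtf
graph, choose a mesh shape and layout, lower, and export a plain TF tensor)
does not require a TPU. The following is a minimal CPU-only sketch assembled
from the same API calls that appear in the file above; the single-device mesh
shape 'all:1', the random batch, and the fixed two-layer network are
illustrative assumptions, not part of the patch.

import mesh_tensorflow as mtf
import numpy
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, 'my_mesh', None)  # no variable placer needed on CPU

# Named dimensions take the place of positional shapes; the layout rules
# below decide how each dimension is split across the mesh.
batch_dim = mtf.Dimension('batch', 64)
io_dim = mtf.Dimension('io', 16)
hidden_dim = mtf.Dimension('hidden_odd', 16)

features = tf.constant(
    numpy.random.uniform(0, 1.0, [64, 16]).astype(numpy.float32))
x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
h = mtf.layers.dense(x, hidden_dim, use_bias=False, name='layer_1')
y = mtf.layers.dense(h, io_dim, use_bias=False, name='layer_2')
loss = mtf.reduce_mean(mtf.square(y - x))

# 'all:1' is a single-device mesh (an assumption for this sketch);
# 'hidden_odd:all' mirrors the layout rule from the example above.
mesh_shape = mtf.convert_to_shape('all:1')
layout_rules = mtf.convert_to_layout_rules('hidden_odd:all')
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mesh_shape, layout_rules, [''] * mesh_shape.size)
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Copy master variables to their slices before reading; this is what
  # MtfRestoreHook does on the Estimator path in the example above.
  sess.run(lowering.copy_masters_to_slices())
  print(sess.run(tf_loss))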