From 355a560513431ddf37a5e9745964527bb0555174 Mon Sep 17 00:00:00 2001
From: Mesh TensorFlow Team
Date: Fri, 16 Jul 2021 11:31:20 -0700
Subject: [PATCH] MoE models with heterogeneous expert width

PiperOrigin-RevId: 385189162
---
 mesh_tensorflow/transformer/moe.py | 222 ++++++++++++++++++++---------
 1 file changed, 157 insertions(+), 65 deletions(-)

diff --git a/mesh_tensorflow/transformer/moe.py b/mesh_tensorflow/transformer/moe.py
index 82fe9455..1679f34c 100644
--- a/mesh_tensorflow/transformer/moe.py
+++ b/mesh_tensorflow/transformer/moe.py
@@ -21,6 +21,7 @@
 TODO(noam): Remove the other copy of this code from tensor2tensor.
 TODO(noam): Write a new, simpler, cleaner version of this code.
 """
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -29,10 +30,9 @@
 import mesh_tensorflow as mtf
 from mesh_tensorflow.transformer import transformer
-
+import numpy as np
 import tensorflow.compat.v1 as tf

-
 @gin.configurable
 class MoE1D(transformer.TransformerLayer):
   """Mixture of Experts Layer."""
@@ -64,7 +64,9 @@ def __init__(self,
                z_loss=None,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
-               use_second_place_expert_prob_temp=None):
+               use_second_place_expert_prob_temp=None,
+               num_layers=1,
+               heterogeneous_mask_info=None):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -93,7 +95,9 @@ def __init__(self,
         moe_use_second_place_expert_prob=(
             use_second_place_expert_prob),
         moe_use_second_place_expert_prob_temp=(
-            use_second_place_expert_prob_temp))
+            use_second_place_expert_prob_temp),
+        moe_num_layers=num_layers,
+        moe_heterogeneous_mask_info=heterogeneous_mask_info)
     self._activation = activation

   def call(self, context, x, losses=None):
@@ -125,7 +129,8 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        context=context)
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -200,7 +205,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype,
     layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, context=None):
   """Local mixture of experts that works well on TPU.

   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -279,6 +284,7 @@
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    context: a Context.
   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
@@ -334,9 +340,24 @@
   #
   # pylint: enable=line-too-long
   orig_inputs = inputs
-  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
-
+  experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)
+  if hparams.moe_heterogeneous_mask_info is not None:
+    tf.logging.info("moe_heterogeneous_mask_info: {}".format(
+        hparams.moe_heterogeneous_mask_info))
+    heterogeneous_mask = generate_heterogeneous_expert_masks(
+        hparams.moe_heterogeneous_mask_info,
+        hparams.moe_num_experts,
+        experts_dim,
+        mesh=inputs.mesh)
+    # Overwrite num_layers and hidden width with the mask dimensions.
+    # TODO(chaimerl): depending on whether the function output is flattened
+    # or not, this might need adjustment.
+    hparams.moe_num_layers = heterogeneous_mask.shape[1].size
+    hparams.moe_hidden_size = heterogeneous_mask.shape[0].size
+  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
+
   # We "cheat" here and look at the mesh shape and layout. This is to ensure
   # that the number of groups is a multiple of the mesh dimension
   # over which those groups are split.
@@ -489,64 +510,81 @@
       input_dim
   ]))

-  # Now feed the expert inputs through the experts.
-  h = mtf.layers.dense_product(
-      expert_inputs,
-      reduced_dims=expert_inputs.shape.dims[-1:],
-      new_dims=[hidden_dim],
-      expert_dims=[experts_dim],
-      activation_functions=activation, use_bias=False,
-      variable_dtype=variable_dtype, name="wi")
-
-  if hparams.moe_dropout_rate != 0.0:
-    h = mtf.dropout(h, is_training=train,
-                    keep_prob=1.0 - hparams.moe_dropout_rate)
-
-  def _compute_output(hidden, layer_name):
-    """Compute the output of the attention layer from the hidden vector."""
-    expert_output = mtf.layers.dense(
-        hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
-        reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
-        name=layer_name)
-
-    # Extra reshape reduces communication cost for model-parallel versions.
-    # For model-parallel versions, this reshape causes an mtf.slice and for non-
-    # model-parallel versions, this has no effect.
-    expert_output = mtf.reshape(
-        expert_output,
-        mtf.Shape([
-            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
-            expert_capacity_dim, d_model_split_dim
-        ]))
-
-    # Split over experts -> split over batch
-    expert_output = mtf.reshape(
-        expert_output,
-        mtf.Shape([
-            outer_batch_dim,
-            experts_dim_unsplit,
-            num_groups_dim,
-            expert_capacity_dim,
-            output_dim,
-        ]))
-    moe_output_dims = moe_input_dims[:-1] + [output_dim]
-    output = mtf.einsum([expert_output, combine_tensor],
-                        mtf.Shape(moe_output_dims))
-    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
-    return output
-
-  if hparams.moe_use_experts_attention:
-    # We share k_h and v_h with no degradation in performance
-    q_h, k_h = h, h
-    outputs = []
-    q = _compute_output(q_h, layer_name="q_wo")
-    k = _compute_output(k_h, layer_name="k_wo")
-    outputs.append(q)
-    outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
-  else:
-    output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+  # If set, heterogeneous_mask has shape
+  # [hidden_size, num_layers, num_experts].
+  for layer in range(hparams.moe_num_layers):
+    with tf.variable_scope("expert_layer_{}".format(layer)):
+      res_h = 0.0
+      if layer > 0:
+        res_h = expert_inputs
+        expert_inputs = transformer.sublayer_rms_norm(
+            expert_inputs, None, context)
+
+      # Now feed the expert inputs through the experts.
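+      # Each iteration applies one expert feed-forward sub-layer: "wi"
+      # projects d_model -> expert_hidden, the (optional) heterogeneous
+      # mask zeroes hidden units beyond each expert's width for this
+      # layer, and "wo" projects back to d_model, with residual
+      # connections and pre-norm linking consecutive layers.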
+      h = mtf.layers.dense_product(
+          expert_inputs,
+          reduced_dims=expert_inputs.shape.dims[-1:],
+          new_dims=[hidden_dim],
+          expert_dims=[experts_dim],
+          activation_functions=activation, use_bias=False,
+          variable_dtype=variable_dtype, name="wi")
+
+      # Apply dropout.
+      if hparams.moe_dropout_rate != 0.0:
+        h = mtf.dropout(h, is_training=train,
+                        keep_prob=1.0 - hparams.moe_dropout_rate)
+
+      # Only applies when experts are heterogeneous.
+      if hparams.moe_heterogeneous_mask_info is not None:
+        # Apply the mask for this expert layer.
+        # TODO(chaimerl): change to include width --> needs to be applied
+        # within the expert --> h
+        heterogeneous_mask_slice = mtf.slice(
+            heterogeneous_mask, layer, 1, heterogeneous_mask.shape[1].name)
+
+        # Get rid of the expert layers dimension.
+        heterogeneous_mask_slice = mtf.reshape(
+            heterogeneous_mask_slice,
+            [heterogeneous_mask_slice.shape[0],
+             heterogeneous_mask_slice.shape[-1]])
+        h *= mtf.cast(heterogeneous_mask_slice, h.dtype)
+
+      # Project back from the expert_hidden dimension to the d_model
+      # (output) dimension.
+      expert_output = mtf.layers.dense(
+          h, output_dim, expert_dims=[experts_dim], use_bias=False,
+          reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
+          name="wo")
+
+      if layer < (hparams.moe_num_layers - 1):
+        expert_output = transformer.sublayer_dropout(
+            expert_output, None, context)
+      expert_output += res_h
+      expert_inputs = expert_output
+
+  # Extra reshape reduces communication cost for model-parallel versions.
+  # For model-parallel versions, this reshape causes an mtf.slice and for non-
+  # model-parallel versions, this has no effect.
+  expert_output = mtf.reshape(
+      expert_output,
+      mtf.Shape([
+          outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+          expert_capacity_dim, d_model_split_dim
+      ]))
+
+  # Split over experts -> split over batch
+  expert_output = mtf.reshape(
+      expert_output,
+      mtf.Shape([
+          outer_batch_dim,
+          experts_dim_unsplit,
+          num_groups_dim,
+          expert_capacity_dim,
+          output_dim,
+      ]))
+  moe_output_dims = moe_input_dims[:-1] + [output_dim]
+  output = mtf.einsum([expert_output, combine_tensor],
+                      mtf.Shape(moe_output_dims))
+  output = mtf.reshape(output, batch_and_length_dims + [output_dim])
+  return output, loss * hparams.moe_loss_coef


 def transformer_moe_layer_v2(
@@ -1720,3 +1758,57 @@ def __init__(self, **kwargs):

   def add_hparam(self, k, v):
     setattr(self, k, v)
+
+
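+# Example: with num_experts=4 and default_width=256,
+#   mask_info = [{'percent_number': .5, 'layers': 1, 'width': 1},
+#                {'percent_number': .5, 'layers': 2, 'width': 2}]
+# yields a 0/1 mask of shape [512, 2, 4]: the first two experts are
+# active over 256 hidden units in layer 0 only, while the last two are
+# active over all 512 hidden units in both layers.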
+def generate_heterogeneous_expert_masks(
+    mask_info, num_experts, experts_dim, mesh, default_width=256):
+  """Returns a 0/1 mask of shape [hidden_size, num_layers, num_experts].
+
+  Args:
+    mask_info: a list of dicts, one per expert group. Each dict gives
+      'percent_number' (the fraction of experts in the group, or an
+      absolute expert count if >= 1.0), 'layers' (the number of active
+      expert layers) and 'width' (the active hidden width, as a multiple
+      of default_width). Example mask_info format:
+        mask_info = [{'percent_number': .5, 'layers': 1, 'width': 1},
+                     {'percent_number': .5, 'layers': 2, 'width': 2}]
+    num_experts: number of experts in the model.
+    experts_dim: mtf Dimension for experts (partitioned).
+    mesh: a mtf.Mesh object.
+    default_width: base hidden width multiplied by each group's 'width'.
+  """
+  max_layers = max(mask_i["layers"] for mask_i in mask_info)
+  max_width = max(mask_i["width"] for mask_i in mask_info) * default_width
+  # Will be shape [max_width, max_layers, num_experts].
+  expert_mask = np.zeros([max_width, max_layers, 0])
+  for idx, mask_i in enumerate(mask_info):
+    if mask_i["percent_number"] < 1.0:
+      num_experts_in_mask = int(num_experts * mask_i["percent_number"])
+    else:
+      # percent_number >= 1.0 is interpreted as an absolute expert count
+      # (a value of exactly 1 therefore means one expert, not all).
+      num_experts_in_mask = int(mask_i["percent_number"])
+    if idx == (len(mask_info) - 1):  # last group
+      # Assign all remaining experts to the last group so the counts sum
+      # to num_experts even when the fractions do not divide it evenly.
+      num_experts_in_mask_tmp = num_experts - expert_mask.shape[2]
+      if num_experts_in_mask_tmp != num_experts_in_mask:
+        tf.logging.info(
+            "Expert layer probabilities do not evenly divide "
+            "the number of experts: {} {}".format(
+                num_experts_in_mask, num_experts_in_mask_tmp))
+      num_experts_in_mask = num_experts_in_mask_tmp
+    mask = np.zeros([int(max_width), int(max_layers),
+                     num_experts_in_mask])
+    # Set ones over this group's active width and layers; the positions
+    # beyond them stay zero.
+    mask[:(mask_i["width"] * default_width), :mask_i["layers"], :] = 1
+    expert_mask = np.concatenate([expert_mask, mask], axis=2)  # expert dim
+  assert expert_mask.shape[2] == num_experts
+  tf.logging.info("heterogeneous mask: {}".format(expert_mask))
+
+  # Now import the numpy mask into Mesh TF.
+  layers_dim = mtf.Dimension("num_expert_layers", max_layers)
+  width_dim = mtf.Dimension("expert_hidden", max_width)
+  expert_mask_tf = tf.convert_to_tensor(expert_mask)
+  expert_mask_mtf = mtf.import_tf_tensor(
+      mesh, tf_tensor=expert_mask_tf,
+      shape=[width_dim, layers_dim, experts_dim])
+  return expert_mask_mtf
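
A minimal usage sketch, not part of the patch: it assumes MoE1D is
instantiated (or gin-configured) as elsewhere in the codebase, with all
constructor arguments other than those shown left at their defaults.

    moe = MoE1D(
        num_experts=4,
        num_layers=2,
        heterogeneous_mask_info=[
            {'percent_number': 0.5, 'layers': 1, 'width': 1},
            {'percent_number': 0.5, 'layers': 2, 'width': 2}])

With this setting, generate_heterogeneous_expert_masks produces a
[512, 2, 4] mask, and transformer_moe_layer_v1 overwrites moe_num_layers
with 2 and moe_hidden_size with 512 from the mask dimensions, so any
separately configured hidden size is superseded.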