This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 16ad1b2

Author: Mesh TensorFlow Team
Splitting tokens when routing
PiperOrigin-RevId: 378316002
1 parent 54b01b4 commit 16ad1b2

File tree

1 file changed: +191 −148 lines

  • mesh_tensorflow/transformer/moe.py

mesh_tensorflow/transformer/moe.py

Lines changed: 191 additions & 148 deletions
@@ -61,7 +61,9 @@ def __init__(self,
                ntlb_top_k=4,
                output_dim=None,
                use_experts_attention=False,
-               z_loss=None):
+               z_loss=None,
+               num_hidden_splits=None,
+               split_hidden_before_routing=False):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -85,7 +87,9 @@ def __init__(self,
         moe_output_dim=output_dim,
         moe_ntlb_top_k=ntlb_top_k,
         moe_use_experts_attention=use_experts_attention,
-        moe_z_loss=z_loss)
+        moe_z_loss=z_loss,
+        moe_num_hidden_splits=num_hidden_splits,
+        moe_split_hidden_before_routing=split_hidden_before_routing)
     self._activation = activation
 
   def call(self, context, x, losses=None):
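The two new constructor arguments flow into the hparams as moe_num_hidden_splits and moe_split_hidden_before_routing. A minimal usage sketch, assuming the enclosing class is moe.MoE1D and using illustrative values for everything except the two new arguments (none of these values appear in the commit itself):

    # Hypothetical usage: split each token into 4 subtokens before routing.
    # The class name and the first three argument values are assumptions.
    moe_layer = moe.MoE1D(
        num_experts=32,
        hidden_size=4096,
        group_size=1024,
        num_hidden_splits=4,               # new in this commit
        split_hidden_before_routing=True)  # new in this commit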
@@ -327,8 +331,8 @@ def transformer_moe_layer_v1(
   # We "cheat" here and look at the mesh shape and layout. This is to ensure
   # that the number of groups is a multiple of the mesh dimension
   # over which those groups are split.
-  batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
-                                      orig_inputs.shape.dims[-1])
+  batch_and_length_dims, orig_input_dim = (
+      orig_inputs.shape.dims[:-1], orig_inputs.shape.dims[-1])
   # Hack: we assume that
   #   "outer_batch" == replication of experts
   # mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -348,16 +352,57 @@ def transformer_moe_layer_v1(
 
   n = n // outer_batch_dim.size
 
-  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
-                                                  orig_batch_dim)
-  num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
-                                              mesh_dim_size)
+  # Create num_groups and group_size dimensions
+  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(
+      layout, mesh_shape, orig_batch_dim)
+  num_groups, group_size = _split_into_groups(
+      n, hparams.moe_group_size, mesh_dim_size)
+  orig_group_size_dim = mtf.Dimension("group", group_size)
+  orig_num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
+
+  # The original dimensions correspond to those before splitting tokens
+  # into subtokens
+  group_size_dim = orig_group_size_dim
+  num_groups_dim = orig_num_groups_dim
+  input_dim = orig_input_dim
+
+  split_hidden_before_routing = False
+  split_hidden_after_routing = False
+  if hparams.moe_num_hidden_splits is not None:
+    if orig_input_dim.size % hparams.moe_num_hidden_splits:
+      raise ValueError("num_hidden_splits {} must divide input_dim {}".format(
+          hparams.moe_num_hidden_splits, input_dim.size))
+    if output_dim.size % hparams.moe_num_hidden_splits:
+      raise ValueError("num_hidden_splits {} must divide output_dim {}".format(
+          hparams.moe_num_hidden_splits, output_dim.size))
+    split_hidden_before_routing = hparams.moe_split_hidden_before_routing
+    split_hidden_after_routing = not hparams.moe_split_hidden_before_routing
+    hidden_dim = mtf.Dimension(
+        "expert_hidden",
+        hparams.moe_hidden_size // hparams.moe_num_hidden_splits)
+    sub_output_dim = mtf.Dimension(
+        output_dim.name, output_dim.size // hparams.moe_num_hidden_splits)
+    num_splits_dim = mtf.Dimension(
+        "num_splits", hparams.moe_num_hidden_splits)
+
+  if split_hidden_before_routing:
+    input_dim = mtf.Dimension(
+        input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)
+
+    # Split into groups and subtokens
+    inputs = mtf.reshape(
+        inputs, [outer_batch_dim, num_groups_dim, group_size_dim,
+                 num_splits_dim, input_dim])
 
-  group_size_dim = mtf.Dimension("group", group_size)
-  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
+    inputs = mtf.transpose(
+        inputs, [outer_batch_dim, num_groups_dim, num_splits_dim,
+                 group_size_dim, input_dim])
 
+    num_groups_dim = mtf.Dimension(
+        orig_batch_dim.name, num_groups * hparams.moe_num_hidden_splits)
+
+  # [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
   moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
-  # OGSM Tensor
   inputs = mtf.reshape(inputs, moe_input_dims)
 
   # Each sequence sends expert_capacity positions to each expert.
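To make the shape bookkeeping in this hunk concrete, here is a NumPy sketch of the before-routing split, with made-up sizes (the real code manipulates named mtf.Dimension objects rather than raw arrays):

    import numpy as np

    # Illustrative sizes only: 2 groups of 4 tokens, d_model=8, 2 splits.
    outer_batch, num_groups, group_size, d_model, num_splits = 1, 2, 4, 8, 2
    sub_d = d_model // num_splits  # feature width of one subtoken

    x = np.random.rand(outer_batch, num_groups, group_size, d_model)
    # Cut each token's features into num_splits subtokens (the first reshape).
    x = x.reshape(outer_batch, num_groups, group_size, num_splits, sub_d)
    # Move the split axis next to the group axis (the mtf.transpose).
    x = x.transpose(0, 1, 3, 2, 4)
    # Fold the split axis into the group count (the final reshape): each
    # subtoken group is now an independent routing unit.
    x = x.reshape(outer_batch, num_groups * num_splits, group_size, sub_d)
    assert x.shape == (1, 4, 4, 4)

After this, a subtoken looks like an ordinary token of width d_model // num_splits, so the gating and dispatch code in the next hunk needs no special cases.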
@@ -373,156 +418,138 @@ def transformer_moe_layer_v1(
   expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
   experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
   batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
+
   if nonpadding is not None:
     nonpadding = mtf.zeros(
         inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
+
+    if split_hidden_before_routing:
+      nonpadding = mtf.reshape(
+          nonpadding,
+          [outer_batch_dim, orig_num_groups_dim, orig_group_size_dim])
+
+      # Tile num_hidden_splits times with an einsum
+      tiling_tensor = mtf.ones(inputs.mesh, [num_splits_dim])
+      nonpadding = mtf.einsum(
+          [nonpadding, tiling_tensor],
+          output_shape=[outer_batch_dim, orig_num_groups_dim, num_splits_dim,
+                        orig_group_size_dim])
+
     nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
-  if hparams.moe_gating == "top_2":
-    # combine_tensor,
-    # dispatch_tensor OG`SEC Tensors
-    # (G is generally split along mesh dim)
-    dispatch_tensor, combine_tensor, loss = _top_2_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "switch":
-    dispatch_tensor, combine_tensor, loss = _switch_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "ntlb":
-    dispatch_tensor, combine_tensor, loss = _ntlb_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "switch_max":
-    dispatch_tensor, combine_tensor, loss = _switch_max_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "expert_selection":
-    dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        group_size_dim=group_size_dim,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        name="expert_selection_gating",
-        num_microbatches=num_microbatches)
-  else:
-    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)
 
-  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
-                             mtf.Shape([
-                                 outer_batch_dim, experts_dim_unsplit,
-                                 num_groups_dim, expert_capacity_dim, input_dim
-                             ]))
+  # [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #  experts_dim_unsplit, expert_capacity_dim]
+  gating_fn = get_gating_fn(hparams.moe_gating)
+  dispatch_tensor, combine_tensor, loss = gating_fn(
+      inputs=inputs,
+      outer_expert_dims=None,
+      experts_dim=experts_dim_unsplit,
+      expert_capacity_dim=expert_capacity_dim,
+      hparams=hparams,
+      train=train,
+      variable_dtype=variable_dtype,
+      importance=nonpadding,
+      num_microbatches=num_microbatches)
+
+  # Dispatch to the experts by reducing group_size_dim
+  # inputs: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
+  # dispatch_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #                   experts_dim_unsplit, expert_capacity_dim]
+  # expert_inputs: [outer_batch_dim, experts_dim_unsplit, num_groups_dim.B,
+  #                 expert_capacity_dim, input_dim]
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, input_dim]
+  expert_inputs = mtf.einsum([inputs, dispatch_tensor], expert_inputs_shape)
 
+  # Split over batch -> split over experts
   # Extra reshape reduces communication cost for model-parallel versions.
   # For model-parallel versions, this reshape causes an mtf.slice and for non-
   # model-parallel versions, this has no effect.
+  # expert_inputs: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
+  #                 expert_capacity_dim, input_dim or input_dim.M]
   d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
-  expert_inputs = mtf.reshape(
-      expert_inputs,
-      mtf.Shape([
-          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
-          d_model_split_dim
-      ]))
-
-  # Split over batch -> split over experts
-  expert_inputs = mtf.reshape(
-      expert_inputs,
-      mtf.Shape([
-          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
-          input_dim
-      ]))
-
-  # Now feed the expert inputs through the experts.
-  h = mtf.layers.dense_product(
-      expert_inputs,
-      reduced_dims=expert_inputs.shape.dims[-1:],
-      new_dims=[hidden_dim],
-      expert_dims=[experts_dim],
-      activation_functions=activation, use_bias=False,
-      variable_dtype=variable_dtype, name="wi")
-
-  if hparams.moe_dropout_rate != 0.0:
-    h = mtf.dropout(h, is_training=train,
-                    keep_prob=1.0 - hparams.moe_dropout_rate)
-
-  def _compute_output(hidden, layer_name):
-    """Compute the output of the attention layer from the hidden vector."""
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim, batch_dim_unsplit,
+      expert_capacity_dim, d_model_split_dim]
+  expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)
+
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim, batch_dim_unsplit,
+      expert_capacity_dim, input_dim]
+  expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)
+
+  def _apply_experts(x, output_dim, hidden_dim):
+    # x: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
+    #     expert_capacity_dim, input_dim]
+    h = mtf.layers.dense_product(
+        x,
+        reduced_dims=x.shape.dims[-1:],
+        new_dims=[hidden_dim],
+        expert_dims=[experts_dim],
+        activation_functions=activation, use_bias=False,
+        variable_dtype=variable_dtype, name="wi")
+
+    if hparams.moe_dropout_rate != 0.0:
+      h = mtf.dropout(h, is_training=train,
+                      keep_prob=1.0 - hparams.moe_dropout_rate)
     expert_output = mtf.layers.dense(
-        hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
-        reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
-        name=layer_name)
-
-    # Extra reshape reduces communication cost for model-parallel versions.
-    # For model-parallel versions, this reshape causes an mtf.slice and for non-
-    # model-parallel versions, this has no effect.
-    expert_output = mtf.reshape(
-        expert_output,
-        mtf.Shape([
-            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
-            expert_capacity_dim, d_model_split_dim
-        ]))
-
-    # Split over experts -> split over batch
+        h, output_dim, expert_dims=[experts_dim], use_bias=False,
+        reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
+        name="wo")
+
+    return expert_output
+
+  if split_hidden_after_routing:
+    input_dim = mtf.Dimension(
+        input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)
+    expert_inputs = mtf.reshape(
+        expert_inputs, expert_inputs.shape[:-1] + [num_splits_dim, input_dim])
+    expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
+    # Concat sub_tokens into tokens
     expert_output = mtf.reshape(
-        expert_output,
-        mtf.Shape([
-            outer_batch_dim,
-            experts_dim_unsplit,
-            num_groups_dim,
-            expert_capacity_dim,
-            output_dim,
-        ]))
-    moe_output_dims = moe_input_dims[:-1] + [output_dim]
-    output = mtf.einsum([expert_output, combine_tensor],
-                        mtf.Shape(moe_output_dims))
-    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
-    return output
-
-  if hparams.moe_use_experts_attention:
-    # We share k_h and v_h with no degradation in performance
-    q_h, k_h = h, h
-    outputs = []
-    q = _compute_output(q_h, layer_name="q_wo")
-    k = _compute_output(k_h, layer_name="k_wo")
-    outputs.append(q)
-    outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+        expert_output, expert_output.shape[:-2] + [output_dim])
+  elif split_hidden_before_routing:
+    expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
   else:
-    output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    expert_output = _apply_experts(expert_inputs, output_dim, hidden_dim)
+
+  # Extra reshape reduces communication cost for model-parallel versions.
+  # For model-parallel versions, this reshape causes an mtf.slice and for non-
+  # model-parallel versions, this has no effect.
+  expert_output_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, d_model_split_dim]
+  expert_output = mtf.reshape(expert_output, expert_output_shape)
+
+  # Split over experts -> split over batch
+  expert_output_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, expert_output.shape[-1]]
+  expert_output = mtf.reshape(expert_output, expert_output_shape)
+
+  # Combine by reducing experts_dim_unsplit and expert_capacity_dim
+  # expert_output: [outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+  #                 expert_capacity_dim, output_dim]
+  # combine_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #                  experts_dim_unsplit, expert_capacity_dim]
+  # output: [outer_batch_dim, num_groups_dim.B, group_size_dim, output_dim]
+  moe_output_dims = moe_input_dims[:-1] + [expert_output.shape[-1]]
+  output = mtf.einsum([expert_output, combine_tensor], moe_output_dims)
+
+  if split_hidden_before_routing:
+    output = mtf.reshape(
+        output, [output.shape[0], orig_num_groups_dim, num_splits_dim] + (
+            output.shape[-2:]))
+    output = mtf.transpose(
+        output, output.shape[:2] + [
+            group_size_dim, num_splits_dim, output.shape[-1]])
+    output = mtf.reshape(output, output.shape[:3] + [output_dim])
+
+  output = mtf.reshape(output, batch_and_length_dims + [output_dim])
+
+  return output, loss * hparams.moe_loss_coef
 
 
 def transformer_moe_layer_v2(
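The dispatch and combine steps in the hunk above are a matched pair of einsums against the gating outputs. A dense NumPy sketch of the pair, with illustrative sizes, random stand-ins for the gating tensors, and the outer_batch axis omitted for brevity:

    import numpy as np

    # g: groups, s: group size, e: experts, c: expert capacity, m: d_model.
    g, s, e, c, m = 2, 4, 3, 2, 8
    inputs = np.random.rand(g, s, m)
    # Stand-ins for the gating outputs; the real dispatch_tensor is mostly
    # zeros, and combine_tensor carries the gate weights.
    dispatch_tensor = np.random.rand(g, s, e, c)
    combine_tensor = dispatch_tensor

    # Dispatch: reduce the token axis s into per-expert capacity buffers.
    expert_inputs = np.einsum("gsm,gsec->egcm", inputs, dispatch_tensor)
    expert_output = expert_inputs  # the experts would transform this here
    # Combine: reduce the expert and capacity axes back to token positions.
    output = np.einsum("egcm,gsec->gsm", expert_output, combine_tensor)
    assert output.shape == (g, s, m)

Because subtokens are treated as ordinary routing units, both einsums are unchanged by this commit; only the surrounding reshapes and the final recombination transpose differ between the split and non-split paths.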
@@ -801,6 +828,22 @@ def transformer_moe_layer_v2(
   return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
 
 
+def get_gating_fn(moe_gating):
+  """Factory for gating functions."""
+  if moe_gating == "top_2":
+    return _top_2_gating
+  elif moe_gating == "switch":
+    return _switch_gating
+  elif moe_gating == "ntlb":
+    return _ntlb_gating
+  elif moe_gating == "switch_max":
+    return _switch_max_gating
+  elif moe_gating == "expert_selection":
+    return _expert_selection_gating
+  else:
+    raise ValueError("unknown hparams.moe_gating=%s" % moe_gating)
+
+
 def _ntlb_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
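This factory replaces the if/elif chain that transformer_moe_layer_v1 previously inlined; the caller now selects and invokes a gating function uniformly, as the new call site in the earlier hunk shows:

    gating_fn = get_gating_fn(hparams.moe_gating)
    dispatch_tensor, combine_tensor, loss = gating_fn(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding,
        num_microbatches=num_microbatches)

One behavioral wrinkle visible in the diff: the old chain passed group_size_dim and name="expert_selection_gating" to _expert_selection_gating, which the uniform call site no longer does, so that path now relies on those parameters' defaults.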
