@@ -61,7 +61,9 @@ def __init__(self,
               ntlb_top_k=4,
               output_dim=None,
               use_experts_attention=False,
-               z_loss=None):
+               z_loss=None,
+               num_hidden_splits=None,
+               split_hidden_before_routing=False):
    self._hparams = HParams(
        moe_gating=moe_gating,
        moe_num_experts=num_experts,
@@ -85,7 +87,9 @@ def __init__(self,
        moe_output_dim=output_dim,
        moe_ntlb_top_k=ntlb_top_k,
        moe_use_experts_attention=use_experts_attention,
-        moe_z_loss=z_loss)
+        moe_z_loss=z_loss,
+        moe_num_hidden_splits=num_hidden_splits,
+        moe_split_hidden_before_routing=split_hidden_before_routing)
    self._activation = activation

  def call(self, context, x, losses=None):
@@ -327,8 +331,8 @@ def transformer_moe_layer_v1(
  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups is a multiple of the mesh dimension
  # over which those groups are split.
-  batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
-                                      orig_inputs.shape.dims[-1])
+  batch_and_length_dims, orig_input_dim = (
+      orig_inputs.shape.dims[:-1], orig_inputs.shape.dims[-1])
  # Hack: we assume that
  #   "outer_batch" == replication of experts
  #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
@@ -348,16 +352,57 @@ def transformer_moe_layer_v1(

  n = n // outer_batch_dim.size

-  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
-                                                  orig_batch_dim)
-  num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
-                                              mesh_dim_size)
+  # Create num_groups and group_size dimensions
+  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(
+      layout, mesh_shape, orig_batch_dim)
+  num_groups, group_size = _split_into_groups(
+      n, hparams.moe_group_size, mesh_dim_size)
+  orig_group_size_dim = mtf.Dimension("group", group_size)
+  orig_num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
+
+  # The original dimensions correspond to those before splitting tokens
+  # into subtokens
+  group_size_dim = orig_group_size_dim
+  num_groups_dim = orig_num_groups_dim
+  input_dim = orig_input_dim
+
+  split_hidden_before_routing = False
+  split_hidden_after_routing = False
+  if hparams.moe_num_hidden_splits is not None:
+    if orig_input_dim.size % hparams.moe_num_hidden_splits:
+      raise ValueError("num_hidden_splits {} must divide input_dim {}".format(
+          hparams.moe_num_hidden_splits, input_dim.size))
+    if output_dim.size % hparams.moe_num_hidden_splits:
+      raise ValueError("num_hidden_splits {} must divide output_dim {}".format(
+          hparams.moe_num_hidden_splits, output_dim.size))
+    split_hidden_before_routing = hparams.moe_split_hidden_before_routing
+    split_hidden_after_routing = not hparams.moe_split_hidden_before_routing
+    hidden_dim = mtf.Dimension(
+        "expert_hidden",
+        hparams.moe_hidden_size // hparams.moe_num_hidden_splits)
+    sub_output_dim = mtf.Dimension(
+        output_dim.name, output_dim.size // hparams.moe_num_hidden_splits)
+    num_splits_dim = mtf.Dimension(
+        "num_splits", hparams.moe_num_hidden_splits)
+
+  if split_hidden_before_routing:
+    input_dim = mtf.Dimension(
+        input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)
+
+    # Split into groups and subtokens
+    inputs = mtf.reshape(
+        inputs, [outer_batch_dim, num_groups_dim, group_size_dim,
+                 num_splits_dim, input_dim])

-  group_size_dim = mtf.Dimension("group", group_size)
-  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)
+    inputs = mtf.transpose(
+        inputs, [outer_batch_dim, num_groups_dim, num_splits_dim,
+                 group_size_dim, input_dim])

+    num_groups_dim = mtf.Dimension(
+        orig_batch_dim.name, num_groups * hparams.moe_num_hidden_splits)
+
+  # [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
  moe_input_dims = [outer_batch_dim, num_groups_dim, group_size_dim, input_dim]
-  # OGSM Tensor
  inputs = mtf.reshape(inputs, moe_input_dims)

  # Each sequence sends expert_capacity positions to each expert.
@@ -373,156 +418,138 @@ def transformer_moe_layer_v1(
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
+
  if nonpadding is not None:
    nonpadding = mtf.zeros(
        inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
+
+    if split_hidden_before_routing:
+      nonpadding = mtf.reshape(
+          nonpadding,
+          [outer_batch_dim, orig_num_groups_dim, orig_group_size_dim])
+
+      # Tile num_hidden_splits times with an einsum
+      tiling_tensor = mtf.ones(inputs.mesh, [num_splits_dim])
+      nonpadding = mtf.einsum(
+          [nonpadding, tiling_tensor],
+          output_shape=[outer_batch_dim, orig_num_groups_dim, num_splits_dim,
+                        orig_group_size_dim])
+
    nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])
-  if hparams.moe_gating == "top_2":
-    # combine_tensor,
-    # dispatch_tensor OG`SEC Tensors
-    # (G is generally split along mesh dim)
-    dispatch_tensor, combine_tensor, loss = _top_2_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "switch":
-    dispatch_tensor, combine_tensor, loss = _switch_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "ntlb":
-    dispatch_tensor, combine_tensor, loss = _ntlb_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "switch_max":
-    dispatch_tensor, combine_tensor, loss = _switch_max_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        num_microbatches=num_microbatches)
-  elif hparams.moe_gating == "expert_selection":
-    dispatch_tensor, combine_tensor, loss = _expert_selection_gating(
-        inputs=inputs,
-        outer_expert_dims=None,
-        experts_dim=experts_dim_unsplit,
-        group_size_dim=group_size_dim,
-        expert_capacity_dim=expert_capacity_dim,
-        hparams=hparams,
-        train=train,
-        variable_dtype=variable_dtype,
-        importance=nonpadding,
-        name="expert_selection_gating",
-        num_microbatches=num_microbatches)
-  else:
-    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

-  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
-                             mtf.Shape([
-                                 outer_batch_dim, experts_dim_unsplit,
-                                 num_groups_dim, expert_capacity_dim, input_dim
-                             ]))
+  # [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #  experts_dim_unsplit, expert_capacity_dim]
+  gating_fn = get_gating_fn(hparams.moe_gating)
+  dispatch_tensor, combine_tensor, loss = gating_fn(
+      inputs=inputs,
+      outer_expert_dims=None,
+      experts_dim=experts_dim_unsplit,
+      expert_capacity_dim=expert_capacity_dim,
+      hparams=hparams,
+      train=train,
+      variable_dtype=variable_dtype,
+      importance=nonpadding,
+      num_microbatches=num_microbatches)
+
+  # Dispatch to the experts by reducing group_size_dim
+  # inputs: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
+  # dispatch_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #                   experts_dim_unsplit, expert_capacity_dim]
+  # expert_inputs: [outer_batch_dim, experts_dim_unsplit, num_groups_dim.B,
+  #                 expert_capacity_dim, input_dim]
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, input_dim]
+  expert_inputs = mtf.einsum([inputs, dispatch_tensor], expert_inputs_shape)

+  # Split over batch -> split over experts
  # Extra reshape reduces communication cost for model-parallel versions.
  # For model-parallel versions, this reshape causes an mtf.slice and for non-
  # model-parallel versions, this has no effect.
+  # expert_inputs: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
+  #                 expert_capacity_dim, input_dim or input_dim.M]
  d_model_split_dim = mtf.Dimension("d_model_split", input_dim.size)
-  expert_inputs = mtf.reshape(
-      expert_inputs,
-      mtf.Shape([
-          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
-          d_model_split_dim
-      ]))
-
-  # Split over batch -> split over experts
-  expert_inputs = mtf.reshape(
-      expert_inputs,
-      mtf.Shape([
-          outer_batch_dim, experts_dim, batch_dim_unsplit, expert_capacity_dim,
-          input_dim
-      ]))
-
-  # Now feed the expert inputs through the experts.
-  h = mtf.layers.dense_product(
-      expert_inputs,
-      reduced_dims=expert_inputs.shape.dims[-1:],
-      new_dims=[hidden_dim],
-      expert_dims=[experts_dim],
-      activation_functions=activation, use_bias=False,
-      variable_dtype=variable_dtype, name="wi")
-
-  if hparams.moe_dropout_rate != 0.0:
-    h = mtf.dropout(h, is_training=train,
-                    keep_prob=1.0 - hparams.moe_dropout_rate)
-
-  def _compute_output(hidden, layer_name):
-    """Compute the output of the attention layer from the hidden vector."""
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim, batch_dim_unsplit,
+      expert_capacity_dim, d_model_split_dim]
+  expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)
+
+  expert_inputs_shape = [
+      outer_batch_dim, experts_dim, batch_dim_unsplit,
+      expert_capacity_dim, input_dim]
+  expert_inputs = mtf.reshape(expert_inputs, expert_inputs_shape)
+
+  def _apply_experts(x, output_dim, hidden_dim):
+    # x: [outer_batch_dim, experts_dim.B, batch_dim_unsplit,
+    #     expert_capacity_dim, input_dim]
+    h = mtf.layers.dense_product(
+        x,
+        reduced_dims=x.shape.dims[-1:],
+        new_dims=[hidden_dim],
+        expert_dims=[experts_dim],
+        activation_functions=activation, use_bias=False,
+        variable_dtype=variable_dtype, name="wi")
+
+    if hparams.moe_dropout_rate != 0.0:
+      h = mtf.dropout(h, is_training=train,
+                      keep_prob=1.0 - hparams.moe_dropout_rate)
    expert_output = mtf.layers.dense(
-        hidden, output_dim, expert_dims=[experts_dim], use_bias=False,
-        reduced_dims=hidden.shape.dims[-1:], variable_dtype=variable_dtype,
-        name=layer_name)
-
-    # Extra reshape reduces communication cost for model-parallel versions.
-    # For model-parallel versions, this reshape causes an mtf.slice and for non-
-    # model-parallel versions, this has no effect.
-    expert_output = mtf.reshape(
-        expert_output,
-        mtf.Shape([
-            outer_batch_dim, experts_dim_unsplit, num_groups_dim,
-            expert_capacity_dim, d_model_split_dim
-        ]))
-
-    # Split over experts -> split over batch
+        h, output_dim, expert_dims=[experts_dim], use_bias=False,
+        reduced_dims=h.shape.dims[-1:], variable_dtype=variable_dtype,
+        name="wo")
+
+    return expert_output
+
+  if split_hidden_after_routing:
+    input_dim = mtf.Dimension(
+        input_dim.name, input_dim.size // hparams.moe_num_hidden_splits)
+    expert_inputs = mtf.reshape(
+        expert_inputs, expert_inputs.shape[:-1] + [num_splits_dim, input_dim])
+    expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
+    # Concat sub_tokens into tokens
    expert_output = mtf.reshape(
-        expert_output,
-        mtf.Shape([
-            outer_batch_dim,
-            experts_dim_unsplit,
-            num_groups_dim,
-            expert_capacity_dim,
-            output_dim,
-        ]))
-    moe_output_dims = moe_input_dims[:-1] + [output_dim]
-    output = mtf.einsum([expert_output, combine_tensor],
-                        mtf.Shape(moe_output_dims))
-    output = mtf.reshape(output, batch_and_length_dims + [output_dim])
-    return output
-
-  if hparams.moe_use_experts_attention:
-    # We share k_h and v_h with no degradation in performance
-    q_h, k_h = h, h
-    outputs = []
-    q = _compute_output(q_h, layer_name="q_wo")
-    k = _compute_output(k_h, layer_name="k_wo")
-    outputs.append(q)
-    outputs.append(k)
-    return outputs, loss * hparams.moe_loss_coef
+        expert_output, expert_output.shape[:-2] + [output_dim])
+  elif split_hidden_before_routing:
+    expert_output = _apply_experts(expert_inputs, sub_output_dim, hidden_dim)
  else:
-    output = _compute_output(h, layer_name="wo")
-    return output, loss * hparams.moe_loss_coef
+    expert_output = _apply_experts(expert_inputs, output_dim, hidden_dim)
+
+  # Extra reshape reduces communication cost for model-parallel versions.
+  # For model-parallel versions, this reshape causes an mtf.slice and for non-
+  # model-parallel versions, this has no effect.
+  expert_output_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, d_model_split_dim]
+  expert_output = mtf.reshape(expert_output, expert_output_shape)
+
+  # Split over experts -> split over batch
+  expert_output_shape = [
+      outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+      expert_capacity_dim, expert_output.shape[-1]]
+  expert_output = mtf.reshape(expert_output, expert_output_shape)
+
+  # Combine by reducing experts_dim_unsplit and expert_capacity_dim
+  # expert_output: [outer_batch_dim, experts_dim_unsplit, num_groups_dim,
+  #                 expert_capacity_dim, output_dim]
+  # combine_tensor: [outer_batch_dim, num_groups_dim.B, group_size_dim,
+  #                  experts_dim_unsplit, expert_capacity_dim]
+  # output: [outer_batch_dim, num_groups_dim.B, group_size_dim, input_dim]
+  moe_output_dims = moe_input_dims[:-1] + [expert_output.shape[-1]]
+  output = mtf.einsum([expert_output, combine_tensor], moe_output_dims)
+
+  if split_hidden_before_routing:
+    output = mtf.reshape(
+        output, [output.shape[0], orig_num_groups_dim, num_splits_dim] + (
+            output.shape[-2:]))
+    output = mtf.transpose(
+        output, output.shape[:2] + [
+            group_size_dim, num_splits_dim, output.shape[-1]])
+    output = mtf.reshape(output, output.shape[:3] + [output_dim])
+
+  output = mtf.reshape(output, batch_and_length_dims + [output_dim])
+
+  return output, loss * hparams.moe_loss_coef


def transformer_moe_layer_v2(
@@ -801,6 +828,22 @@ def transformer_moe_layer_v2(
  return output, (loss_outer + loss_inner) * hparams.moe_loss_coef


+def get_gating_fn(moe_gating):
+  """Factory for gating functions."""
+  if moe_gating == "top_2":
+    return _top_2_gating
+  elif moe_gating == "switch":
+    return _switch_gating
+  elif moe_gating == "ntlb":
+    return _ntlb_gating
+  elif moe_gating == "switch_max":
+    return _switch_max_gating
+  elif moe_gating == "expert_selection":
+    return _expert_selection_gating
+  else:
+    raise ValueError("unknown hparams.moe_gating=%s" % moe_gating)
+
+
def _ntlb_gating(inputs,
                 outer_expert_dims,
                 experts_dim,
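
A minimal sketch (not part of the commit) of the token-splitting idea this diff introduces: each token of width d_model is cut into num_hidden_splits sub-tokens of width d_model // num_hidden_splits, the sub-tokens are routed as independent units, and their outputs are concatenated back into full tokens afterwards. The NumPy code below assumes a plain [batch, length, d_model] activation and uses a dummy per-sub-token transform in place of the real gating and expert layers; it only illustrates the reshape bookkeeping for the "split before routing" variant, not the Mesh TensorFlow group and mesh handling.

import numpy as np

batch, length, d_model, num_splits = 2, 8, 16, 4
sub_dim = d_model // num_splits  # width of each sub-token

x = np.random.randn(batch, length, d_model)

# "Split hidden before routing": [B, L, M] -> [B, L * S, M / S].
# Each sub-token becomes its own routable position, so the number of routable
# units grows by num_splits while their width shrinks by the same factor.
sub_tokens = x.reshape(batch, length, num_splits, sub_dim)
routable = sub_tokens.reshape(batch, length * num_splits, sub_dim)

# Stand-in for gating plus expert feed-forward: any transform mapping each
# sub-token of width sub_dim to an output of the same width.
expert_out = 2.0 * routable

# Combine: undo the split and concatenate sub-token outputs back into tokens.
output = expert_out.reshape(batch, length, num_splits, sub_dim)
output = output.reshape(batch, length, d_model)

assert output.shape == x.shape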