@@ -57,6 +57,11 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None):
         self._wrapped_criterion = None
         self._wrapped_model = None
 
+        # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
+        # It is less flexible and syncs only the default stats.
+        self._all_reduce_list = [0.0] * 6
+        self.fast_stat_sync = args.fast_stat_sync
+
         self.init_meters(args)
 
     def init_meters(self, args):
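
The six slots of the buffer introduced above are purely positional; the diff never names them. A sketch of the implied layout, recovered from the accumulation and unpacking code later in this diff (the constant names are ours, for illustration only):

    # Implied slot order of self._all_reduce_list. One flat list of
    # doubles lets a single all_reduce over a 6-element tensor stand in
    # for a pickled all_gather_list round trip.
    SAMPLE_SIZE, NSENTENCES, LOSS, NLL_LOSS, NTOKENS, OOMS = range(6)

    buf = [0.0] * 6            # == self._all_reduce_list at construction
    buf[SAMPLE_SIZE] += 512.0  # filled per batch inside train_step
    buf[OOMS] += 1.0           # bumped once per step, after the batch loop
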
@@ -292,6 +297,13 @@ def maybe_no_sync():
                 if not ignore_grad:
                     logging_outputs.append(logging_output)
                     sample_sizes.append(sample_size)
+
+                    if self.fast_stat_sync:
+                        self._all_reduce_list[0] += sample_size
+                        self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
+                        self._all_reduce_list[2] += logging_output.get('loss', 0.0)
+                        self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
+                        self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
             except RuntimeError as e:
                 if 'out of memory' in str(e):
                     msg = (
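
The accumulators read each stat with .get(key, 0.0), so a criterion that does not report, say, nll_loss simply contributes nothing to that slot. Note also that loss and nll_loss are accumulated as raw sums; normalization happens once, after the cross-worker reduce in the next hunk. A minimal sketch, assuming a logging_output shaped like those fairseq's built-in criteria emit:

    # Hypothetical per-batch logging output; the keys match those read above.
    logging_output = {'loss': 710.2, 'nll_loss': 655.4, 'ntokens': 4096}
    buf = [0.0] * 6
    buf[1] += logging_output.get('nsentences', 0.0)  # missing key -> adds 0.0
    buf[2] += logging_output.get('loss', 0.0)        # summed, not averaged
    buf[3] += logging_output.get('nll_loss', 0.0)
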
@@ -311,20 +323,41 @@ def maybe_no_sync():
                 else:
                     raise e
 
+        if self.fast_stat_sync:
+            self._all_reduce_list[5] += ooms
+
+
         if ooms > 0 and self._oom_batch is not None:
             self.handle_ooms(ooms)
 
         if dummy_batch:
             return None
 
         # gather logging outputs from all replicas
-        if self.args.distributed_world_size > 1 and (
-            (not self.args.use_bmuf)
-            or (
-                self.args.use_bmuf
-                and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
+        if self.fast_stat_sync:
+            # rework all_gather_list
+            all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
+            if self._sync_stats():
+                torch.distributed.all_reduce(all_reduce_list_tensor)
+                # Normalize loss and nll_loss by "sample_size"
+                # and convert to log base 2
+                all_reduce_list_tensor[2:4].div_(
+                    (
+                        all_reduce_list_tensor[0:1] *
+                        torch.log(torch.cuda.DoubleTensor([2]))
+                    )
                 )
-        ):
+            self._all_reduce_list = all_reduce_list_tensor.tolist()
+            logging_output = {}
+            [
+                sample_size,
+                logging_output['nsentences'],
+                logging_output['loss'],
+                logging_output['nll_loss'],
+                logging_output['ntokens'],
+                ooms,
+            ] = self._all_reduce_list
+        elif self._sync_stats():
             logging_outputs, sample_sizes, ooms, prev_norms = \
                 zip(*distributed_utils.all_gather_list(
                     [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
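
The div_ above folds two steps into one in-place op on the reduced tensor: averaging the summed losses by the global sample_size (slot 0, a one-element slice that broadcasts over slots 2:4) and converting from nats to bits ("log base 2"). A worked check of that arithmetic, with made-up numbers:

    import math

    total_nll, sample_size = 655.4, 8192.0  # sums across all workers
    bits_per_sample = total_nll / (sample_size * math.log(2))
    # Identical to normalizing first, then changing base:
    assert abs(bits_per_sample - (total_nll / sample_size) / math.log(2)) < 1e-12
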
@@ -345,11 +378,12 @@ def maybe_no_sync():
             self.zero_grad()
             return None
 
-        # aggregate logging outputs and sample sizes
-        logging_output = self.task.aggregate_logging_outputs(
-            logging_outputs, self.get_criterion()
-        )
-        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
+        if not self.fast_stat_sync:
+            # aggregate logging outputs and sample sizes
+            logging_output = self.task.aggregate_logging_outputs(
+                logging_outputs, self.get_criterion()
+            )
+            sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
 
         if not all(k in logging_output for k in ['ntokens', 'nsentences']):
             raise Exception((
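
For contrast, the slow path preserved here hands aggregation back to the task and criterion, which for the stock criteria amounts to summing each stat over the gathered per-worker outputs before normalizing. A rough sketch of that key-wise summation (not fairseq's exact code):

    def aggregate(logging_outputs):
        # Sum every reported stat across workers; arbitrary keys survive,
        # which is the flexibility the fast path trades away.
        agg = {}
        for log in logging_outputs:
            for k, v in log.items():
                agg[k] = agg.get(k, 0) + v
        return agg
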
@@ -400,6 +434,7 @@ def maybe_no_sync():
             self.meters['loss_scale'].reset()
             self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)
 
+        self.clear_buffered_stats()
         self.meters['train_wall'].stop()
 
         return logging_output
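
clear_buffered_stats() runs unconditionally at the end of every train_step, not just when fast_stat_sync is set, so each step's all_reduce only ever sees its own accumulations. A usage-level sketch, with a hypothetical trainer instance and batch list:

    # Hypothetical objects, for illustration: once a step completes,
    # the accumulators are back to zero.
    logging_output = trainer.train_step(samples)
    assert trainer._all_reduce_list == [0.0] * 6
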
@@ -484,6 +519,9 @@ def handle_ooms(self, number_of_ooms):
     def zero_grad(self):
         self.optimizer.zero_grad()
 
+    def clear_buffered_stats(self):
+        self._all_reduce_list = [0.0] * 6
+
     def lr_step(self, epoch, val_loss=None):
         """Adjust the learning rate based on the validation loss."""
         self.lr_scheduler.step(epoch, val_loss)
@@ -545,3 +583,15 @@ def _set_seed(self):
         torch.manual_seed(seed)
         if self.cuda:
             torch.cuda.manual_seed(seed)
+
+    def _sync_stats(self):
+        return (
+            self.args.distributed_world_size > 1 and
+            (
+                (not self.args.use_bmuf) or
+                (
+                    self.args.use_bmuf
+                    and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
+                )
+            )
+        )
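
_sync_stats() simply factors out the condition that used to be inlined at the gather site: sync whenever there is more than one worker, except under BMUF, where stats only sync on every global_sync_iter-th update. A flattened mirror of that predicate, with hypothetical argument values:

    def sync_stats(world_size, use_bmuf, num_updates, global_sync_iter=50):
        # Same logic as _sync_stats() above, with args unpacked into
        # parameters; (not bmuf) or (bmuf and X) simplifies to (not bmuf) or X.
        return world_size > 1 and (
            not use_bmuf or (num_updates + 1) % global_sync_iter == 0
        )

    assert sync_stats(8, False, 0)      # plain DDP: sync every update
    assert not sync_stats(8, True, 0)   # BMUF: wait for the interval
    assert sync_stats(8, True, 49)      # (49 + 1) % 50 == 0
    assert not sync_stats(1, False, 0)  # single worker: nothing to sync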