
Commit e1ba32a

Naman Goyal authored and facebook-github-bot committed

added fast stats sync option (#858)

Summary: Added `--fast-stat-sync` option. This avoids pickle and achieves ~7% higher `wps` on 16 nodes. It is less flexible, as it aggregates only a fixed set of basic stats and ignores the aggregation function defined by the criterion. Let me know what you think @myleott

Pull Request resolved: fairinternal/fairseq-py#858
Differential Revision: D17398770
fbshipit-source-id: 36261a1d970e67deeda8211af8f009ef9b4f9c14

1 parent 1fd8943 commit e1ba32a
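For context on the pickle cost the summary refers to: the existing slow path pickles each worker's `logging_output` dict and gathers the bytes from every replica (fairseq's `distributed_utils.all_gather_list`). A rough, self-contained sketch of that pattern follows; the helper name and buffer size are illustrative, not fairseq's actual implementation, and it falls back to single-process behavior when no process group is initialized.

```python
import pickle

import torch
import torch.distributed as dist


def gather_logging_outputs(logging_output, max_size=16384):
    """Pickle an arbitrary stats dict and gather a copy from every worker.

    Flexible (any picklable dict works) but pays for pickling plus
    O(world_size) buffer traffic on every training step.
    """
    data = pickle.dumps(logging_output)
    assert len(data) <= max_size, 'stats dict too large for the buffer'
    buf = torch.zeros(max_size, dtype=torch.uint8)
    buf[:len(data)] = torch.tensor(list(data), dtype=torch.uint8)
    if dist.is_available() and dist.is_initialized():
        gathered = [torch.zeros_like(buf) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, buf)
        # pickle stops at its STOP opcode, so the zero padding is ignored
        return [pickle.loads(bytes(t.tolist())) for t in gathered]
    return [logging_output]  # single-process fallback


print(gather_logging_outputs({'loss': 530.4, 'ntokens': 192, 'nsentences': 8}))
```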

4 files changed: 66 additions & 11 deletions

fairseq/criterions/cross_entropy.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -30,6 +30,7 @@ def forward(self, model, sample, reduce=True):
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'nsentences': sample['target'].size(0),
             'sample_size': sample_size,
```

fairseq/criterions/masked_lm.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -47,6 +47,7 @@ def forward(self, model, sample, reduce=True):
 
         logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
+            'nll_loss': utils.item(loss.data) if reduce else loss.data,
             'ntokens': sample['ntokens'],
             'nsentences': sample['nsentences'],
             'sample_size': sample_size,
```
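Both criterion tweaks above exist to serve the new fast path: trainer.py reads a hardcoded set of keys with `dict.get(key, 0.0)`, so any criterion that never logs `nll_loss` would silently contribute zero to that slot. For plain cross-entropy the loss already is the negative log-likelihood, hence the duplication. A toy illustration with made-up numbers:

```python
# Made-up numbers: dict.get() turns a missing key into 0.0 instead of an
# error, so the fast path would silently report nll_loss = 0 for any
# criterion that does not log it.
logging_output = {'loss': 530.4, 'ntokens': 192, 'nsentences': 8, 'sample_size': 192}
print(logging_output.get('nll_loss', 0.0))           # 0.0 -- stat silently dropped
logging_output['nll_loss'] = logging_output['loss']  # cross-entropy loss is the NLL
print(logging_output.get('nll_loss', 0.0))           # 530.4 -- slot now populated
```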

fairseq/options.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -332,6 +332,9 @@ def add_distributed_training_args(parser):
     group.add_argument('--find-unused-parameters', default=False, action='store_true',
                        help='disable unused parameter detection (not applicable to '
                             'no_c10d ddp-backend')
+    group.add_argument('--fast-stat-sync', default=False, action='store_true',
+                       help='Enable fast sync of stats between nodes, this hardcodes to '
+                            'sync only some default stats from logging_output.')
     # fmt: on
     return group
```
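To see the new flag in isolation, here is a standalone argparse sketch; the parser below is illustrative rather than fairseq's full option group:

```python
import argparse

# Illustrative parser, not fairseq's full add_distributed_training_args group.
parser = argparse.ArgumentParser()
parser.add_argument('--fast-stat-sync', default=False, action='store_true',
                    help='Enable fast sync of stats between nodes, this hardcodes to '
                         'sync only some default stats from logging_output.')

args = parser.parse_args(['--fast-stat-sync'])
assert args.fast_stat_sync  # argparse exposes the dashed flag as fast_stat_sync
```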
fairseq/trainer.py
Lines changed: 61 additions & 11 deletions

```diff
@@ -57,6 +57,11 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None):
         self._wrapped_criterion = None
         self._wrapped_model = None
 
+        # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
+        # It is less flexible and syncs only the default stats.
+        self._all_reduce_list = [0.0] * 6
+        self.fast_stat_sync = args.fast_stat_sync
+
         self.init_meters(args)
 
     def init_meters(self, args):
@@ -292,6 +297,13 @@ def maybe_no_sync():
                 if not ignore_grad:
                     logging_outputs.append(logging_output)
                     sample_sizes.append(sample_size)
+
+                    if self.fast_stat_sync:
+                        self._all_reduce_list[0] += sample_size
+                        self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
+                        self._all_reduce_list[2] += logging_output.get('loss', 0.0)
+                        self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
+                        self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
             except RuntimeError as e:
                 if 'out of memory' in str(e):
                     msg = (
@@ -311,20 +323,41 @@ def maybe_no_sync():
             else:
                 raise e
 
+        if self.fast_stat_sync:
+            self._all_reduce_list[5] += ooms
+
+
         if ooms > 0 and self._oom_batch is not None:
             self.handle_ooms(ooms)
 
         if dummy_batch:
             return None
 
         # gather logging outputs from all replicas
-        if self.args.distributed_world_size > 1 and (
-            (not self.args.use_bmuf)
-            or (
-                self.args.use_bmuf
-                and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
+        if self.fast_stat_sync:
+            # rework all_gather_list
+            all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
+            if self._sync_stats():
+                torch.distributed.all_reduce(all_reduce_list_tensor)
+                # Normalize loss and nll_loss by "sample_size"
+                # and convert to log base 2
+                all_reduce_list_tensor[2:4].div_(
+                    (
+                        all_reduce_list_tensor[0:1] *
+                        torch.log(torch.cuda.DoubleTensor([2]))
+                    )
             )
-        ):
+            self._all_reduce_list = all_reduce_list_tensor.tolist()
+            logging_output = {}
+            [
+                sample_size,
+                logging_output['nsentences'],
+                logging_output['loss'],
+                logging_output['nll_loss'],
+                logging_output['ntokens'],
+                ooms,
+            ] = self._all_reduce_list
+        elif self._sync_stats():
             logging_outputs, sample_sizes, ooms, prev_norms = \
                 zip(*distributed_utils.all_gather_list(
                     [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
@@ -345,11 +378,12 @@ def maybe_no_sync():
             self.zero_grad()
             return None
 
-        # aggregate logging outputs and sample sizes
-        logging_output = self.task.aggregate_logging_outputs(
-            logging_outputs, self.get_criterion()
-        )
-        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
+        if not self.fast_stat_sync:
+            # aggregate logging outputs and sample sizes
+            logging_output = self.task.aggregate_logging_outputs(
+                logging_outputs, self.get_criterion()
+            )
+            sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())
 
         if not all(k in logging_output for k in ['ntokens', 'nsentences']):
             raise Exception((
@@ -400,6 +434,7 @@ def maybe_no_sync():
             self.meters['loss_scale'].reset()
             self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)
 
+        self.clear_buffered_stats()
         self.meters['train_wall'].stop()
 
         return logging_output
@@ -484,6 +519,9 @@ def handle_ooms(self, number_of_ooms):
     def zero_grad(self):
         self.optimizer.zero_grad()
 
+    def clear_buffered_stats(self):
+        self._all_reduce_list = [0.0] * 6
+
     def lr_step(self, epoch, val_loss=None):
         """Adjust the learning rate based on the validation loss."""
         self.lr_scheduler.step(epoch, val_loss)
@@ -545,3 +583,15 @@ def _set_seed(self):
         torch.manual_seed(seed)
         if self.cuda:
             torch.cuda.manual_seed(seed)
+
+    def _sync_stats(self):
+        return (
+            self.args.distributed_world_size > 1 and
+            (
+                (not self.args.use_bmuf) or
+                (
+                    self.args.use_bmuf
+                    and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
+                )
+            )
+        )
```
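Putting the trainer changes together: one flat all-reduce over a fixed six-slot tensor ([sample_size, nsentences, loss, nll_loss, ntokens, ooms]) replaces per-worker pickling. The sketch below is a CPU approximation of that path (the commit itself uses `torch.cuda.DoubleTensor`); it assumes an initialized process group for the reduction and degrades to a single-process no-op otherwise.

```python
import math

import torch
import torch.distributed as dist


def fast_stat_sync(buffered_stats):
    """Sum six scalar stats across workers with a single all_reduce.

    Slot layout, as in the patch:
    [sample_size, nsentences, loss, nll_loss, ntokens, ooms].
    """
    stats = torch.DoubleTensor(buffered_stats)  # flat tensor: nothing to pickle
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats)  # one fused sum instead of gathering dicts
        # Normalize loss/nll_loss (slots 2:4) by the global sample_size and
        # convert from nats to bits, mirroring the patch's div_ / log(2).
        stats[2:4].div_(stats[0:1] * math.log(2))
    sample_size, nsentences, loss, nll_loss, ntokens, ooms = stats.tolist()
    return sample_size, ooms, {
        'loss': loss,
        'nll_loss': nll_loss,
        'ntokens': ntokens,
        'nsentences': nsentences,
        'sample_size': sample_size,
    }


# One worker's buffered stats for a step (made-up values).
print(fast_stat_sync([192.0, 8.0, 530.4, 530.4, 192.0, 0.0]))
```

The tradeoff is the one the summary names: the slot layout is fixed, so any task-specific stats produced by the criterion's aggregation function are simply not synced.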
