
Commit 2f7e3f3

Myle Ott authored and facebook-github-bot committed
Support multi-GPU validation in fairseq-validate (#2162)
Summary: Pull Request resolved: #2162
Reviewed By: ngoyal2707
Differential Revision: D21663181
Pulled By: myleott
fbshipit-source-id: d01e64f97482f76bd601cd8b20232c0ef637bb8a
1 parent be5313a commit 2f7e3f3

3 files changed: 16 additions & 4 deletions

fairseq/criterions/adaptive_loss.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def __init__(self, task, sentence_avg):
 
     @classmethod
     def build_criterion(cls, args, task):
-        if args.ddp_backend == 'c10d':
+        if getattr(args, 'ddp_backend', None) == 'c10d':
             raise Exception(
                 'AdaptiveLoss is not compatible with the c10d '
                 'version of DistributedDataParallel. Please use '
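
The switch to getattr is defensive: an argument namespace built without the distributed training flags has no ddp_backend attribute, so a plain attribute access would raise. A minimal illustration with a hypothetical namespace (not part of the commit):

    from argparse import Namespace

    args = Namespace(path='checkpoint.pt')  # hypothetical args lacking ddp_backend

    # args.ddp_backend would raise AttributeError here; getattr falls back to
    # None, so the c10d compatibility check simply passes for such namespaces
    if getattr(args, 'ddp_backend', None) == 'c10d':
        raise Exception('AdaptiveLoss is not compatible with c10d')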

fairseq/options.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def get_eval_lm_parser(default_task="language_modeling"):
 def get_validation_parser(default_task=None):
     parser = get_parser("Validation", default_task)
     add_dataset_args(parser, train=True)
+    add_distributed_training_args(parser)
     group = parser.add_argument_group("Evaluation")
     add_common_eval_args(group)
     return parser
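
With add_distributed_training_args now part of get_validation_parser itself, the validation parser accepts flags such as --distributed-world-size out of the box. A minimal sketch (dataset and checkpoint paths are placeholders):

    from fairseq import options

    parser = options.get_validation_parser(default_task='language_modeling')
    args = options.parse_args_and_arch(parser, input_args=[
        'data-bin/wikitext-103',          # placeholder data directory
        '--path', 'checkpoint_best.pt',   # placeholder checkpoint
        '--distributed-world-size', '2',
    ])
    print(args.distributed_world_size)    # -> 2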

fairseq_cli/validate.py

Lines changed: 14 additions & 3 deletions
@@ -5,14 +5,15 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from itertools import chain
 import logging
 import sys
 
 import torch
 
 from fairseq import checkpoint_utils, distributed_utils, options, utils
 from fairseq.logging import metrics, progress_bar
-from fairseq.options import add_distributed_training_args
+
 
 logging.basicConfig(
     format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
@@ -32,6 +33,9 @@ def main(args, override_args=None):
     use_fp16 = args.fp16
     use_cuda = torch.cuda.is_available() and not args.cpu
 
+    if use_cuda:
+        torch.cuda.set_device(args.device_id)
+
     if override_args is not None:
         overrides = vars(override_args)
         overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
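
Pinning the device matters because each validation worker owns exactly one GPU; without torch.cuda.set_device, every rank's .cuda() calls would default to GPU 0. A small standalone illustration (assumes at least two visible GPUs; not part of the commit):

    import torch

    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)       # pin this process to GPU 1
        x = torch.zeros(1).cuda()      # .cuda() with no index now means cuda:1
        print(x.device)                # -> cuda:1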
@@ -80,6 +84,8 @@ def main(args, override_args=None):
             ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
             required_batch_size_multiple=args.required_batch_size_multiple,
             seed=args.seed,
+            num_shards=args.distributed_world_size,
+            shard_id=args.distributed_rank,
             num_workers=args.num_workers,
         ).next_epoch_itr(shuffle=False)
         progress = progress_bar.progress_bar(
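
num_shards and shard_id make each rank iterate over a disjoint slice of the validation batches. A simplified sketch of the round-robin sharding that fairseq's iterators apply (not the actual iterator class):

    def shard_batches(batches, shard_id, num_shards):
        # rank r sees batches r, r + num_shards, r + 2 * num_shards, ...
        for i, batch in enumerate(batches):
            if i % num_shards == shard_id:
                yield batch

    # e.g. with 2 shards: rank 0 gets [0, 2, 4], rank 1 gets [1, 3, 5]
    print(list(shard_batches(range(6), shard_id=0, num_shards=2)))  # [0, 2, 4]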
@@ -97,6 +103,13 @@ def main(args, override_args=None):
             progress.log(log_output, step=i)
             log_outputs.append(log_output)
 
+        if args.distributed_world_size > 1:
+            log_outputs = distributed_utils.all_gather_list(
+                log_outputs,
+                max_size=getattr(args, 'all_gather_list_size', 16384),
+            )
+            log_outputs = list(chain.from_iterable(log_outputs))
+
         with metrics.aggregate() as agg:
             task.reduce_metrics(log_outputs, criterion)
             log_output = agg.get_smoothed_values()
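
all_gather_list serializes each rank's local list of log dicts and hands every rank a copy of all of them; flattening with chain.from_iterable then restores one flat list, so task.reduce_metrics sees every batch exactly once. The same gather-then-flatten pattern expressed with plain torch.distributed (a sketch, not part of the commit; assumes an initialized process group):

    from itertools import chain
    import torch.distributed as dist

    def gather_log_outputs(log_outputs):
        # one slot per rank; all_gather_object is torch's generic-object
        # analogue of fairseq's all_gather_list
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, log_outputs)
        # flatten the list of per-rank lists into one list of per-batch logs
        return list(chain.from_iterable(gathered))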
@@ -106,12 +119,10 @@
 
 def cli_main():
     parser = options.get_validation_parser()
-    add_distributed_training_args(parser)
     args = options.parse_args_and_arch(parser)
 
     # only override args that are explicitly given on the command line
     override_parser = options.get_validation_parser()
-    add_distributed_training_args(override_parser)
     override_args = options.parse_args_and_arch(override_parser, suppress_defaults=True)
 
     distributed_utils.call_main(args, main, override_args=override_args)
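
Since distributed_utils.call_main handles the distributed setup, no separate launcher is needed; passing --distributed-world-size is enough. A hypothetical invocation (dataset, task, and checkpoint paths are placeholders):

    fairseq-validate data-bin/wikitext-103 \
        --task language_modeling \
        --path checkpoints/checkpoint_best.pt \
        --max-tokens 4096 \
        --distributed-world-size 2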
