55# This source code is licensed under the MIT license found in the
66# LICENSE file in the root directory of this source tree.
77
8+ from itertools import chain
89import logging
910import sys
1011
1112import torch
1213
1314from fairseq import checkpoint_utils , distributed_utils , options , utils
1415from fairseq .logging import metrics , progress_bar
15- from fairseq . options import add_distributed_training_args
16+
1617
1718logging .basicConfig (
1819 format = '%(asctime)s | %(levelname)s | %(name)s | %(message)s' ,
@@ -32,6 +33,9 @@ def main(args, override_args=None):
3233 use_fp16 = args .fp16
3334 use_cuda = torch .cuda .is_available () and not args .cpu
3435
36+ if use_cuda :
37+ torch .cuda .set_device (args .device_id )
38+
3539 if override_args is not None :
3640 overrides = vars (override_args )
3741 overrides .update (eval (getattr (override_args , 'model_overrides' , '{}' )))
@@ -80,6 +84,8 @@ def main(args, override_args=None):
8084 ignore_invalid_inputs = args .skip_invalid_size_inputs_valid_test ,
8185 required_batch_size_multiple = args .required_batch_size_multiple ,
8286 seed = args .seed ,
87+ num_shards = args .distributed_world_size ,
88+ shard_id = args .distributed_rank ,
8389 num_workers = args .num_workers ,
8490 ).next_epoch_itr (shuffle = False )
8591 progress = progress_bar .progress_bar (
@@ -97,6 +103,13 @@ def main(args, override_args=None):
97103 progress .log (log_output , step = i )
98104 log_outputs .append (log_output )
99105
106+ if args .distributed_world_size > 1 :
107+ log_outputs = distributed_utils .all_gather_list (
108+ log_outputs ,
109+ max_size = getattr (args , 'all_gather_list_size' , 16384 ),
110+ )
111+ log_outputs = list (chain .from_iterable (log_outputs ))
112+
100113 with metrics .aggregate () as agg :
101114 task .reduce_metrics (log_outputs , criterion )
102115 log_output = agg .get_smoothed_values ()
@@ -106,12 +119,10 @@ def main(args, override_args=None):
106119
107120def cli_main ():
108121 parser = options .get_validation_parser ()
109- add_distributed_training_args (parser )
110122 args = options .parse_args_and_arch (parser )
111123
112124 # only override args that are explicitly given on the command line
113125 override_parser = options .get_validation_parser ()
114- add_distributed_training_args (override_parser )
115126 override_args = options .parse_args_and_arch (override_parser , suppress_defaults = True )
116127
117128 distributed_utils .call_main (args , main , override_args = override_args )
0 commit comments