@@ -132,14 +132,18 @@ def setup_task(cls, args, **kwargs):
132132 dicts , training = MultilingualTranslationTask .prepare (args , ** kwargs )
133133 return cls (args , dicts , training )
134134
135- def load_dataset (self , split , ** kwargs ):
135+ def load_dataset (self , split , epoch = 0 , ** kwargs ):
136136 """Load a dataset split."""
137137
138+ paths = self .args .data .split (':' )
139+ assert len (paths ) > 0
140+ data_path = paths [epoch % len (paths )]
141+
138142 def split_exists (split , src , tgt , lang ):
139143 if src is not None :
140- filename = os .path .join (self . args . data , '{}.{}-{}.{}' .format (split , src , tgt , lang ))
144+ filename = os .path .join (data_path , '{}.{}-{}.{}' .format (split , src , tgt , lang ))
141145 else :
142- filename = os .path .join (self . args . data , '{}.{}-None.{}' .format (split , src , tgt ))
146+ filename = os .path .join (data_path , '{}.{}-None.{}' .format (split , src , tgt ))
143147 if self .args .raw_text and IndexedRawTextDataset .exists (filename ):
144148 return True
145149 elif not self .args .raw_text and IndexedDataset .exists (filename ):
@@ -162,25 +166,25 @@ def indexed_dataset(path, dictionary):
162166 for lang_pair in self .args .lang_pairs :
163167 src , tgt = lang_pair .split ('-' )
164168 if split_exists (split , src , tgt , src ):
165- prefix = os .path .join (self . args . data , '{}.{}-{}.' .format (split , src , tgt ))
169+ prefix = os .path .join (data_path , '{}.{}-{}.' .format (split , src , tgt ))
166170 elif split_exists (split , tgt , src , src ):
167- prefix = os .path .join (self . args . data , '{}.{}-{}.' .format (split , tgt , src ))
171+ prefix = os .path .join (data_path , '{}.{}-{}.' .format (split , tgt , src ))
168172 else :
169173 continue
170174 src_datasets [lang_pair ] = indexed_dataset (prefix + src , self .dicts [src ])
171175 tgt_datasets [lang_pair ] = indexed_dataset (prefix + tgt , self .dicts [tgt ])
172- print ('| parallel-{} {} {} examples' .format (self . args . data , split , len (src_datasets [lang_pair ])))
176+ print ('| parallel-{} {} {} examples' .format (data_path , split , len (src_datasets [lang_pair ])))
173177 if len (src_datasets ) == 0 :
174- raise FileNotFoundError ('Dataset not found: {} ({})' .format (split , self . args . data ))
178+ raise FileNotFoundError ('Dataset not found: {} ({})' .format (split , data_path ))
175179
176180 # back translation datasets
177181 backtranslate_datasets = {}
178182 if (self .lambda_otf_bt > 0.0 or self .lambda_otf_bt_steps is not None ) and split .startswith ("train" ):
179183 for lang_pair in self .args .lang_pairs :
180184 src , tgt = lang_pair .split ('-' )
181185 if not split_exists (split , tgt , None , tgt ):
182- raise FileNotFoundError ('Dataset not found: backtranslation {} ({})' .format (split , self . args . data ))
183- filename = os .path .join (self . args . data , '{}.{}-None.{}' .format (split , tgt , tgt ))
186+ raise FileNotFoundError ('Dataset not found: backtranslation {} ({})' .format (split , data_path ))
187+ filename = os .path .join (data_path , '{}.{}-None.{}' .format (split , tgt , tgt ))
184188 dataset = indexed_dataset (filename , self .dicts [tgt ])
185189 lang_pair_dataset_tgt = LanguagePairDataset (
186190 dataset ,
@@ -216,7 +220,7 @@ def indexed_dataset(path, dictionary):
216220 ).collater ,
217221 )
218222 print ('| backtranslate-{}: {} {} {} examples' .format (
219- tgt , self . args . data , split , len (backtranslate_datasets [lang_pair ]),
223+ tgt , data_path , split , len (backtranslate_datasets [lang_pair ]),
220224 ))
221225 self .backtranslate_datasets [lang_pair ] = backtranslate_datasets [lang_pair ]
222226
@@ -227,7 +231,7 @@ def indexed_dataset(path, dictionary):
227231 _ , tgt = lang_pair .split ('-' )
228232 if not split_exists (split , tgt , None , tgt ):
229233 continue
230- filename = os .path .join (self . args . data , '{}.{}-None.{}' .format (split , tgt , tgt ))
234+ filename = os .path .join (data_path , '{}.{}-None.{}' .format (split , tgt , tgt ))
231235 tgt_dataset1 = indexed_dataset (filename , self .dicts [tgt ])
232236 tgt_dataset2 = indexed_dataset (filename , self .dicts [tgt ])
233237 noising_dataset = NoisingDataset (
@@ -255,7 +259,7 @@ def indexed_dataset(path, dictionary):
255259 tgt_lang = tgt ,
256260 )
257261 print ('| denoising-{}: {} {} {} examples' .format (
258- tgt , self . args . data , split , len (noising_datasets [lang_pair ]),
262+ tgt , data_path , split , len (noising_datasets [lang_pair ]),
259263 ))
260264
261265 def language_pair_dataset (lang_pair ):
0 commit comments