Skip to content
This repository was archived by the owner on Mar 20, 2026. It is now read-only.

Commit 0add50c

Browse files
Naman Goyalfacebook-github-bot
authored andcommitted
allowing sharded dataset (#696)
Summary: Co-authored-by: myleott <myleott@fb.com> Changing `data` to be `str` with colon separated list for loading sharded datasets. This change is useful for loading large datasets that cannot fit into, memory. The large dataset can be sharded and then each shard is loaded in one epoch in roudrobin manner. For example, if there are `5` shards of data and `10` epochs then the shards will be iterated upon `[0, 1, 2, 3, 4, 0, 1, 2, 3, 4]`. myleott We need to look into `translation.py` as it currently already expects a list and then concats the datasets. Pull Request resolved: #696 Differential Revision: D15214049 fbshipit-source-id: 03e43a7b69c7aefada2ca668abf1eac1969fe013
1 parent 57da383 commit 0add50c

9 files changed

Lines changed: 138 additions & 89 deletions

fairseq/data/iterators.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,12 @@ class EpochBatchIterator(object):
7979
num_workers (int, optional): how many subprocesses to use for data
8080
loading. 0 means the data will be loaded in the main process
8181
(default: 0).
82+
epoch (int, optional): The epoch to start the iterator from.
8283
"""
8384

8485
def __init__(
8586
self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
86-
num_workers=0,
87+
num_workers=0, epoch=0,
8788
):
8889
assert isinstance(dataset, torch.utils.data.Dataset)
8990
self.dataset = dataset
@@ -94,7 +95,7 @@ def __init__(
9495
self.shard_id = shard_id
9596
self.num_workers = num_workers
9697

97-
self.epoch = 0
98+
self.epoch = epoch
9899
self._cur_epoch_itr = None
99100
self._next_epoch_itr = None
100101
self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)

fairseq/tasks/cross_lingual_lm.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ class CrossLingualLMTask(FairseqTask):
4242
@staticmethod
4343
def add_args(parser):
4444
"""Add task-specific arguments to the parser."""
45-
parser.add_argument('data', help='path to data directory')
45+
parser.add_argument('data', help='colon separated path to data directories list, \
46+
will be iterated upon during epochs in round-robin manner')
4647
parser.add_argument('--tokens-per-sample', default=512, type=int,
4748
help='max number of total tokens over all segments'
4849
' per sample')
@@ -106,12 +107,16 @@ def setup_task(cls, args, **kwargs):
106107

107108
return cls(args, dictionary)
108109

109-
def _load_single_lang_dataset(self, split):
110+
def _load_single_lang_dataset(self, split, epoch):
110111
loaded_datasets = []
111112

113+
paths = self.args.data.split(':')
114+
assert len(paths) > 0
115+
data_path = paths[epoch % len(paths)]
116+
112117
for k in itertools.count():
113118
split_k = split + (str(k) if k > 0 else '')
114-
path = os.path.join(self.args.data, split_k)
119+
path = os.path.join(data_path, split_k)
115120

116121
if self.args.raw_text and IndexedRawTextDataset.exists(path):
117122
ds = IndexedRawTextDataset(path, self.dictionary)
@@ -124,7 +129,7 @@ def _load_single_lang_dataset(self, split):
124129
if k > 0:
125130
break
126131
else:
127-
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
132+
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
128133

129134
# Since we append each block with the classification_token,
130135
# we need to effectively create blocks of length
@@ -136,7 +141,7 @@ def _load_single_lang_dataset(self, split):
136141
)
137142
)
138143

139-
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
144+
print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))
140145

141146
if len(loaded_datasets) == 1:
142147
dataset = loaded_datasets[0]
@@ -147,7 +152,7 @@ def _load_single_lang_dataset(self, split):
147152

148153
return dataset, sizes
149154

150-
def load_dataset(self, split, combine=False, **kwargs):
155+
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
151156
"""Load a given dataset split.
152157
Args:
153158
split (str): name of the split (e.g., train, valid, test)
@@ -162,7 +167,7 @@ def load_dataset(self, split, combine=False, **kwargs):
162167
# Datasets are expected to be in "split.lang" format (Eg: train.en)
163168
language_split = '{}.{}'.format(split, lang)
164169

165-
block_dataset, sizes = self._load_single_lang_dataset(split=language_split)
170+
block_dataset, sizes = self._load_single_lang_dataset(split=language_split, epoch=epoch)
166171

167172
dataset_map[lang] = MaskedLMDataset(
168173
dataset=block_dataset,
@@ -182,6 +187,6 @@ def load_dataset(self, split, combine=False, **kwargs):
182187
dataset_map, default_key=self.default_key
183188
)
184189
print('| {} {} {} examples'.format(
185-
self.args.data, split, len(self.datasets[split])
190+
self.args.data.split(':')[epoch], split, len(self.datasets[split])
186191
)
187-
)
192+
)

fairseq/tasks/fairseq_task.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def dataset(self, split):
9292
def get_batch_iterator(
9393
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
9494
ignore_invalid_inputs=False, required_batch_size_multiple=1,
95-
seed=1, num_shards=1, shard_id=0, num_workers=0,
95+
seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=0,
9696
):
9797
"""
9898
Get an iterator that yields batches of data from the given dataset.
@@ -118,6 +118,7 @@ def get_batch_iterator(
118118
num_workers (int, optional): how many subprocesses to use for data
119119
loading. 0 means the data will be loaded in the main process
120120
(default: 0).
121+
epoch (int, optional): The epoch to start the iterator from.
121122
122123
Returns:
123124
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
@@ -149,6 +150,7 @@ def get_batch_iterator(
149150
num_shards=num_shards,
150151
shard_id=shard_id,
151152
num_workers=num_workers,
153+
epoch=epoch,
152154
)
153155

154156
def build_model(self, args):

fairseq/tasks/language_modeling.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def setup_task(cls, args, **kwargs):
104104
dictionary = None
105105
output_dictionary = None
106106
if args.data:
107-
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
107+
paths = args.data.split(':')
108+
assert len(paths) > 0
109+
dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
108110
print('| dictionary: {} types'.format(len(dictionary)))
109111
output_dictionary = dictionary
110112
if args.output_dictionary_size >= 0:
@@ -136,7 +138,7 @@ def build_model(self, args):
136138

137139
return model
138140

139-
def load_dataset(self, split, combine=False, **kwargs):
141+
def load_dataset(self, split, epoch=0, combine=False, **kwargs):
140142
"""Load a given dataset split.
141143
142144
Args:
@@ -145,9 +147,13 @@ def load_dataset(self, split, combine=False, **kwargs):
145147

146148
loaded_datasets = []
147149

150+
paths = self.args.data.split(':')
151+
assert len(paths) > 0
152+
data_path = paths[epoch % len(paths)]
153+
148154
for k in itertools.count():
149155
split_k = split + (str(k) if k > 0 else '')
150-
path = os.path.join(self.args.data, split_k)
156+
path = os.path.join(data_path, split_k)
151157

152158
if self.args.raw_text and IndexedRawTextDataset.exists(path):
153159
ds = IndexedRawTextDataset(path, self.dictionary)
@@ -160,7 +166,7 @@ def load_dataset(self, split, combine=False, **kwargs):
160166
if k > 0:
161167
break
162168
else:
163-
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
169+
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
164170

165171
loaded_datasets.append(
166172
TokenBlockDataset(
@@ -170,7 +176,7 @@ def load_dataset(self, split, combine=False, **kwargs):
170176
)
171177
)
172178

173-
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
179+
print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))
174180

175181
if not combine:
176182
break

fairseq/tasks/multilingual_translation.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def prepare(cls, args, **kargs):
135135
# load dictionaries
136136
dicts = OrderedDict()
137137
for lang in sorted_langs:
138-
dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
138+
paths = args.data.split(':')
139+
assert len(paths) > 0
140+
dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
139141
if len(dicts) > 0:
140142
assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
141143
assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
@@ -185,11 +187,15 @@ def alter_dataset_langtok(self, lang_pair_dataset,
185187
new_tgt_bos=new_tgt_bos,
186188
)
187189

188-
def load_dataset(self, split, **kwargs):
190+
def load_dataset(self, split, epoch=0, **kwargs):
189191
"""Load a dataset split."""
190192

193+
paths = self.args.data.split(':')
194+
assert len(paths) > 0
195+
data_path = paths[epoch % len(paths)]
196+
191197
def split_exists(split, src, tgt, lang):
192-
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
198+
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
193199
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
194200
return True
195201
elif not self.args.raw_text and IndexedDataset.exists(filename):
@@ -210,17 +216,17 @@ def indexed_dataset(path, dictionary):
210216
for lang_pair in self.args.lang_pairs:
211217
src, tgt = lang_pair.split('-')
212218
if split_exists(split, src, tgt, src):
213-
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
219+
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
214220
elif split_exists(split, tgt, src, src):
215-
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
221+
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
216222
else:
217223
continue
218224
src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
219225
tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
220-
print('| {} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
226+
print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
221227

222228
if len(src_datasets) == 0:
223-
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
229+
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
224230

225231
def language_pair_dataset(lang_pair):
226232
src, tgt = lang_pair.split('-')

fairseq/tasks/semisupervised_translation.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,18 @@ def setup_task(cls, args, **kwargs):
132132
dicts, training = MultilingualTranslationTask.prepare(args, **kwargs)
133133
return cls(args, dicts, training)
134134

135-
def load_dataset(self, split, **kwargs):
135+
def load_dataset(self, split, epoch=0, **kwargs):
136136
"""Load a dataset split."""
137137

138+
paths = self.args.data.split(':')
139+
assert len(paths) > 0
140+
data_path = paths[epoch % len(paths)]
141+
138142
def split_exists(split, src, tgt, lang):
139143
if src is not None:
140-
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
144+
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
141145
else:
142-
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, src, tgt))
146+
filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, src, tgt))
143147
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
144148
return True
145149
elif not self.args.raw_text and IndexedDataset.exists(filename):
@@ -162,25 +166,25 @@ def indexed_dataset(path, dictionary):
162166
for lang_pair in self.args.lang_pairs:
163167
src, tgt = lang_pair.split('-')
164168
if split_exists(split, src, tgt, src):
165-
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
169+
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, src, tgt))
166170
elif split_exists(split, tgt, src, src):
167-
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
171+
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
168172
else:
169173
continue
170174
src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
171175
tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
172-
print('| parallel-{} {} {} examples'.format(self.args.data, split, len(src_datasets[lang_pair])))
176+
print('| parallel-{} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))
173177
if len(src_datasets) == 0:
174-
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
178+
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
175179

176180
# back translation datasets
177181
backtranslate_datasets = {}
178182
if (self.lambda_otf_bt > 0.0 or self.lambda_otf_bt_steps is not None) and split.startswith("train"):
179183
for lang_pair in self.args.lang_pairs:
180184
src, tgt = lang_pair.split('-')
181185
if not split_exists(split, tgt, None, tgt):
182-
raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, self.args.data))
183-
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
186+
raise FileNotFoundError('Dataset not found: backtranslation {} ({})'.format(split, data_path))
187+
filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
184188
dataset = indexed_dataset(filename, self.dicts[tgt])
185189
lang_pair_dataset_tgt = LanguagePairDataset(
186190
dataset,
@@ -216,7 +220,7 @@ def indexed_dataset(path, dictionary):
216220
).collater,
217221
)
218222
print('| backtranslate-{}: {} {} {} examples'.format(
219-
tgt, self.args.data, split, len(backtranslate_datasets[lang_pair]),
223+
tgt, data_path, split, len(backtranslate_datasets[lang_pair]),
220224
))
221225
self.backtranslate_datasets[lang_pair] = backtranslate_datasets[lang_pair]
222226

@@ -227,7 +231,7 @@ def indexed_dataset(path, dictionary):
227231
_, tgt = lang_pair.split('-')
228232
if not split_exists(split, tgt, None, tgt):
229233
continue
230-
filename = os.path.join(self.args.data, '{}.{}-None.{}'.format(split, tgt, tgt))
234+
filename = os.path.join(data_path, '{}.{}-None.{}'.format(split, tgt, tgt))
231235
tgt_dataset1 = indexed_dataset(filename, self.dicts[tgt])
232236
tgt_dataset2 = indexed_dataset(filename, self.dicts[tgt])
233237
noising_dataset = NoisingDataset(
@@ -255,7 +259,7 @@ def indexed_dataset(path, dictionary):
255259
tgt_lang=tgt,
256260
)
257261
print('| denoising-{}: {} {} {} examples'.format(
258-
tgt, self.args.data, split, len(noising_datasets[lang_pair]),
262+
tgt, data_path, split, len(noising_datasets[lang_pair]),
259263
))
260264

261265
def language_pair_dataset(lang_pair):

0 commit comments

Comments
 (0)