Skip to content
This repository was archived by the owner on Mar 20, 2026. It is now read-only.

Commit 2a9b4ec

Browse files
Spencer Poff authored and facebook-github-bot committed
More thorough support for iterable datasets
Summary: Using PyTorch IterableDataset for streaming iterators, so that there is a clean differentiation in interface between datasets that stream data and those that support indexed access. Reviewed By: myleott Differential Revision: D18438694 fbshipit-source-id: 482857d8357091ea2a6bf819535b09ba7f1a5b7d
1 parent b31849a commit 2a9b4ec

4 files changed

Lines changed: 37 additions & 5 deletions

File tree

fairseq/data/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .dictionary import Dictionary, TruncatedDictionary
77

8-
from .fairseq_dataset import FairseqDataset
8+
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
99

1010
from .base_wrapper_dataset import BaseWrapperDataset
1111

@@ -65,6 +65,7 @@
6565
'Dictionary',
6666
'EpochBatchIterator',
6767
'FairseqDataset',
68+
'FairseqIterableDataset',
6869
'GroupedIterator',
6970
'IdDataset',
7071
'IndexedCachedDataset',

fairseq/data/fairseq_dataset.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,15 @@
77
import torch.utils.data
88

99

10-
class FairseqDataset(torch.utils.data.Dataset):
10+
class EpochListening:
    """Mixin that lets an object react to epoch boundaries.

    Subclasses override :meth:`set_epoch` to refresh any per-epoch
    state; the default implementation does nothing.
    """

    def set_epoch(self, epoch):
        """Called with the new epoch number at the start of each epoch."""
        pass
16+
17+
18+
class FairseqDataset(torch.utils.data.Dataset, EpochListening):
1119
"""A dataset that provides helpers for batching."""
1220

1321
def __getitem__(self, index):
@@ -54,5 +62,11 @@ def prefetch(self, indices):
5462
"""Prefetch the data required for this epoch."""
5563
raise NotImplementedError
5664

57-
def set_epoch(self, epoch):
58-
pass
65+
66+
class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """Base class for datasets that must be consumed sequentially —
    typically because the samples are streamed in, or cannot otherwise
    be held and indexed on a single machine.
    """

    def __iter__(self):
        # Abstract: concrete streaming datasets yield samples in order.
        raise NotImplementedError

fairseq/data/iterators.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,15 @@ def __len__(self) -> int:
6363
raise NotImplementedError
6464

6565
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
    """Advance to the next epoch and return a fresh iterator over the dataset.

    Args:
        shuffle (bool, optional): if True, reorder the batches before
            handing back the iterator (default: True).
        fix_batches_to_gpus (bool, optional): pin each batch to the same
            shard on every epoch; requires that :attr:`dataset` supports
            prefetching (default: False).
    """
    raise NotImplementedError  # abstract: implemented by concrete iterators
6776

6877
def end_of_epoch(self) -> bool:
@@ -71,20 +80,23 @@ def end_of_epoch(self) -> bool:
7180

7281
@property
def iterations_in_epoch(self) -> int:
    """How many batches have been consumed so far in the current epoch."""
    raise NotImplementedError  # abstract property
7585

7686
def state_dict(self):
    """Return a dictionary describing the complete state of the iterator."""
    raise NotImplementedError  # abstract: implemented by concrete iterators
7889

7990
def load_state_dict(self, state_dict):
    """Restore the iterator's state from the given *state_dict*."""
    raise NotImplementedError  # abstract: implemented by concrete iterators
8193

8294

8395
class StreamingEpochBatchIterator(EpochBatchIterating):
8496
def __init__(
8597
self, dataset, epoch=0, num_shards=1, shard_id=0,
8698
):
87-
# assert isinstance(dataset, torch.utils.data.Dataset)
99+
assert isinstance(dataset, torch.utils.data.IterableDataset)
88100
self.dataset = dataset
89101
self.epoch = epoch
90102
self._current_epoch_iterator = None
@@ -93,6 +105,7 @@ def __init__(
93105

94106
def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
95107
self.epoch += 1
108+
self.dataset.set_epoch(self.epoch)
96109
self._current_epoch_iterator = CountingIterator(
97110
iterable=ShardedIterator(
98111
iterable=self.dataset,

fairseq/data/list_dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ def __init__(self, dataset, sizes=None):
1212
super().__init__(dataset)
1313
self._sizes = sizes
1414

15+
def __iter__(self):
16+
for x in self.dataset:
17+
yield x
18+
1519
def collater(self, samples):
    """Identity collater: hand the sample list back unchanged as the batch."""
    return samples
1721

0 commit comments

Comments
 (0)