Skip to content

Commit ac561bc

Browse files
authored
hint shuffling in prototype datasets rather than actually applying it (#5111)
* hint shuffling rather than actually shuffling prototype datasets * cleanup after merge
1 parent eac3dc7 commit ac561bc

File tree

10 files changed

+34
-27
lines changed

10 files changed

+34
-27
lines changed

torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torchdata.datapipes.iter import (
99
IterDataPipe,
1010
Mapper,
11-
Shuffler,
1211
Filter,
1312
IterKeyZipper,
1413
)
@@ -20,7 +19,7 @@
2019
OnlineResource,
2120
DatasetType,
2221
)
23-
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding
22+
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
2423
from torchvision.prototype.features import Label, BoundingBox, Feature
2524

2625

@@ -121,7 +120,7 @@ def _make_datapipe(
121120

122121
images_dp = Filter(images_dp, self._is_not_background_image)
123122
images_dp = hint_sharding(images_dp)
124-
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
123+
images_dp = hint_shuffling(images_dp)
125124

126125
anns_dp = Filter(anns_dp, self._is_ann)
127126

@@ -185,7 +184,7 @@ def _make_datapipe(
185184
dp = resource_dps[0]
186185
dp = Filter(dp, self._is_not_rogue_file)
187186
dp = hint_sharding(dp)
188-
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
187+
dp = hint_shuffling(dp)
189188
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
190189

191190
def _generate_categories(self, root: pathlib.Path) -> List[str]:

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from torchdata.datapipes.iter import (
77
IterDataPipe,
88
Mapper,
9-
Shuffler,
109
Filter,
1110
Zipper,
1211
IterKeyZipper,
@@ -19,7 +18,13 @@
1918
OnlineResource,
2019
DatasetType,
2120
)
22-
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor, hint_sharding
21+
from torchvision.prototype.datasets.utils._internal import (
22+
INFINITE_BUFFER_SIZE,
23+
getitem,
24+
path_accessor,
25+
hint_sharding,
26+
hint_shuffling,
27+
)
2328

2429

2530
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
@@ -152,7 +157,7 @@ def _make_datapipe(
152157
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
153158
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
154159
splits_dp = hint_sharding(splits_dp)
155-
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
160+
splits_dp = hint_shuffling(splits_dp)
156161

157162
anns_dp = Zipper(
158163
*[

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
IterDataPipe,
1212
Filter,
1313
Mapper,
14-
Shuffler,
1514
)
1615
from torchvision.prototype.datasets.decoder import raw
1716
from torchvision.prototype.datasets.utils import (
@@ -23,7 +22,7 @@
2322
DatasetType,
2423
)
2524
from torchvision.prototype.datasets.utils._internal import (
26-
INFINITE_BUFFER_SIZE,
25+
hint_shuffling,
2726
image_buffer_from_array,
2827
path_comparator,
2928
hint_sharding,
@@ -89,7 +88,7 @@ def _make_datapipe(
8988
dp = Mapper(dp, self._unpickle)
9089
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
9190
dp = hint_sharding(dp)
92-
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
91+
dp = hint_shuffling(dp)
9392
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
9493

9594
def _generate_categories(self, root: pathlib.Path) -> List[str]:

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torchdata.datapipes.iter import (
99
IterDataPipe,
1010
Mapper,
11-
Shuffler,
1211
Filter,
1312
Demultiplexer,
1413
Grouper,
@@ -31,6 +30,7 @@
3130
getitem,
3231
path_accessor,
3332
hint_sharding,
33+
hint_shuffling,
3434
)
3535
from torchvision.prototype.features import BoundingBox, Label, Feature
3636
from torchvision.prototype.utils._internal import FrozenMapping
@@ -182,7 +182,7 @@ def _make_datapipe(
182182

183183
if config.annotations is None:
184184
dp = hint_sharding(images_dp)
185-
dp = Shuffler(dp)
185+
dp = hint_shuffling(dp)
186186
return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
187187

188188
meta_dp = Filter(
@@ -208,7 +208,7 @@ def _make_datapipe(
208208
anns_meta_dp = UnBatcher(anns_meta_dp)
209209
anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
210210
anns_meta_dp = hint_sharding(anns_meta_dp)
211-
anns_meta_dp = Shuffler(anns_meta_dp)
211+
anns_meta_dp = hint_shuffling(anns_meta_dp)
212212

213213
anns_dp = IterKeyZipper(
214214
anns_meta_dp,

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
55

66
import torch
7-
from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
7+
from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter
88
from torchvision.prototype.datasets.utils import (
99
Dataset,
1010
DatasetConfig,
@@ -21,6 +21,7 @@
2121
getitem,
2222
read_mat,
2323
hint_sharding,
24+
hint_shuffling,
2425
)
2526
from torchvision.prototype.features import Label
2627
from torchvision.prototype.utils._internal import FrozenMapping
@@ -141,15 +142,15 @@ def _make_datapipe(
141142
# the train archive is a tar of tars
142143
dp = TarArchiveReader(images_dp)
143144
dp = hint_sharding(dp)
144-
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
145+
dp = hint_shuffling(dp)
145146
dp = Mapper(dp, self._collate_train_data)
146147
elif config.split == "val":
147148
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
148149
devkit_dp = LineReader(devkit_dp, return_path=False)
149150
devkit_dp = Mapper(devkit_dp, int)
150151
devkit_dp = Enumerator(devkit_dp, 1)
151152
devkit_dp = hint_sharding(devkit_dp)
152-
devkit_dp = Shuffler(devkit_dp, buffer_size=INFINITE_BUFFER_SIZE)
153+
devkit_dp = hint_shuffling(devkit_dp)
153154

154155
dp = IterKeyZipper(
155156
devkit_dp,
@@ -161,7 +162,7 @@ def _make_datapipe(
161162
dp = Mapper(dp, self._collate_val_data)
162163
else: # config.split == "test"
163164
dp = hint_sharding(images_dp)
164-
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
165+
dp = hint_shuffling(dp)
165166
dp = Mapper(dp, self._collate_test_data)
166167

167168
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
Demultiplexer,
1313
Mapper,
1414
Zipper,
15-
Shuffler,
1615
)
1716
from torchvision.prototype.datasets.decoder import raw
1817
from torchvision.prototype.datasets.utils import (
@@ -29,6 +28,7 @@
2928
INFINITE_BUFFER_SIZE,
3029
fromfile,
3130
hint_sharding,
31+
hint_shuffling,
3232
)
3333
from torchvision.prototype.features import Image, Label
3434

@@ -135,7 +135,7 @@ def _make_datapipe(
135135

136136
dp = Zipper(images_dp, labels_dp)
137137
dp = hint_sharding(dp)
138-
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
138+
dp = hint_shuffling(dp)
139139
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
140140

141141

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torchdata.datapipes.iter import (
99
IterDataPipe,
1010
Mapper,
11-
Shuffler,
1211
Demultiplexer,
1312
Filter,
1413
IterKeyZipper,
@@ -29,6 +28,7 @@
2928
path_accessor,
3029
path_comparator,
3130
hint_sharding,
31+
hint_shuffling,
3232
)
3333

3434

@@ -141,7 +141,7 @@ def _make_datapipe(
141141
split_dp = extra_split_dp
142142
split_dp = LineReader(split_dp, decode=True)
143143
split_dp = hint_sharding(split_dp)
144-
split_dp = Shuffler(split_dp)
144+
split_dp = hint_shuffling(split_dp)
145145

146146
dp = split_dp
147147
for level, data_dp in enumerate((images_dp, anns_dp)):

torchvision/prototype/datasets/_builtin/semeion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torchdata.datapipes.iter import (
66
IterDataPipe,
77
Mapper,
8-
Shuffler,
98
CSVParser,
109
)
1110
from torchvision.prototype.datasets.decoder import raw
@@ -17,7 +16,7 @@
1716
OnlineResource,
1817
DatasetType,
1918
)
20-
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array, hint_sharding
19+
from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
2120

2221

2322
class SEMEION(Dataset):
@@ -65,6 +64,6 @@ def _make_datapipe(
6564
dp = resource_dps[0]
6665
dp = CSVParser(dp, delimiter=" ")
6766
dp = hint_sharding(dp)
68-
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
67+
dp = hint_shuffling(dp)
6968
dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
7069
return dp

torchvision/prototype/datasets/_builtin/voc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torchdata.datapipes.iter import (
99
IterDataPipe,
1010
Mapper,
11-
Shuffler,
1211
Filter,
1312
Demultiplexer,
1413
IterKeyZipper,
@@ -29,6 +28,7 @@
2928
INFINITE_BUFFER_SIZE,
3029
path_comparator,
3130
hint_sharding,
31+
hint_shuffling,
3232
)
3333

3434
HERE = pathlib.Path(__file__).parent
@@ -131,7 +131,7 @@ def _make_datapipe(
131131
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
132132
split_dp = LineReader(split_dp, decode=True)
133133
split_dp = hint_sharding(split_dp)
134-
split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE)
134+
split_dp = hint_shuffling(split_dp)
135135

136136
dp = split_dp
137137
for level, data_dp in enumerate((images_dp, anns_dp)):

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import torch
3131
import torch.distributed as dist
3232
import torch.utils.data
33-
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter
33+
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter, Shuffler
3434
from torchdata.datapipes.utils import StreamWrapper
3535

3636

@@ -335,3 +335,7 @@ def read_flo(file: BinaryIO) -> torch.Tensor:
335335

336336
def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
337337
return ShardingFilter(datapipe)
338+
339+
340+
def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
341+
return Shuffler(datapipe, default=False, buffer_size=INFINITE_BUFFER_SIZE)

0 commit comments

Comments
 (0)