Skip to content

hint shuffling in prototype datasets rather than actually applying it #5111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions torchvision/prototype/datasets/_builtin/caltech.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
Shuffler,
Filter,
IterKeyZipper,
)
Expand All @@ -20,7 +19,7 @@
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
from torchvision.prototype.features import Label, BoundingBox, Feature


Expand Down Expand Up @@ -121,7 +120,7 @@ def _make_datapipe(

images_dp = Filter(images_dp, self._is_not_background_image)
images_dp = hint_sharding(images_dp)
images_dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
images_dp = hint_shuffling(images_dp)

anns_dp = Filter(anns_dp, self._is_ann)

Expand Down Expand Up @@ -185,7 +184,7 @@ def _make_datapipe(
dp = resource_dps[0]
dp = Filter(dp, self._is_not_rogue_file)
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
Expand Down
11 changes: 8 additions & 3 deletions torchvision/prototype/datasets/_builtin/celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
Shuffler,
Filter,
Zipper,
IterKeyZipper,
Expand All @@ -19,7 +18,13 @@
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, getitem, path_accessor, hint_sharding
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
getitem,
path_accessor,
hint_sharding,
hint_shuffling,
)


csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
Expand Down Expand Up @@ -152,7 +157,7 @@ def _make_datapipe(
splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
splits_dp = hint_sharding(splits_dp)
splits_dp = Shuffler(splits_dp, buffer_size=INFINITE_BUFFER_SIZE)
splits_dp = hint_shuffling(splits_dp)

anns_dp = Zipper(
*[
Expand Down
5 changes: 2 additions & 3 deletions torchvision/prototype/datasets/_builtin/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
IterDataPipe,
Filter,
Mapper,
Shuffler,
)
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Expand All @@ -23,7 +22,7 @@
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
hint_shuffling,
image_buffer_from_array,
path_comparator,
hint_sharding,
Expand Down Expand Up @@ -89,7 +88,7 @@ def _make_datapipe(
dp = Mapper(dp, self._unpickle)
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
Expand Down
6 changes: 3 additions & 3 deletions torchvision/prototype/datasets/_builtin/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
Shuffler,
Filter,
Demultiplexer,
Grouper,
Expand All @@ -31,6 +30,7 @@
getitem,
path_accessor,
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import BoundingBox, Label, Feature
from torchvision.prototype.utils._internal import FrozenMapping
Expand Down Expand Up @@ -182,7 +182,7 @@ def _make_datapipe(

if config.annotations is None:
dp = hint_sharding(images_dp)
dp = Shuffler(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))

meta_dp = Filter(
Expand All @@ -208,7 +208,7 @@ def _make_datapipe(
anns_meta_dp = UnBatcher(anns_meta_dp)
anns_meta_dp = Grouper(anns_meta_dp, group_key_fn=getitem("image_id"), buffer_size=INFINITE_BUFFER_SIZE)
anns_meta_dp = hint_sharding(anns_meta_dp)
anns_meta_dp = Shuffler(anns_meta_dp)
anns_meta_dp = hint_shuffling(anns_meta_dp)

anns_dp = IterKeyZipper(
anns_meta_dp,
Expand Down
9 changes: 5 additions & 4 deletions torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, cast

import torch
from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter, Shuffler
from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, TarArchiveReader, Filter
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
Expand All @@ -21,6 +21,7 @@
getitem,
read_mat,
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Label
from torchvision.prototype.utils._internal import FrozenMapping
Expand Down Expand Up @@ -141,15 +142,15 @@ def _make_datapipe(
# the train archive is a tar of tars
dp = TarArchiveReader(images_dp)
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_train_data)
elif config.split == "val":
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
devkit_dp = LineReader(devkit_dp, return_path=False)
devkit_dp = Mapper(devkit_dp, int)
devkit_dp = Enumerator(devkit_dp, 1)
devkit_dp = hint_sharding(devkit_dp)
devkit_dp = Shuffler(devkit_dp, buffer_size=INFINITE_BUFFER_SIZE)
devkit_dp = hint_shuffling(devkit_dp)

dp = IterKeyZipper(
devkit_dp,
Expand All @@ -161,7 +162,7 @@ def _make_datapipe(
dp = Mapper(dp, self._collate_val_data)
else: # config.split == "test"
dp = hint_sharding(images_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_test_data)

return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Demultiplexer,
Mapper,
Zipper,
Shuffler,
)
from torchvision.prototype.datasets.decoder import raw
from torchvision.prototype.datasets.utils import (
Expand All @@ -29,6 +28,7 @@
INFINITE_BUFFER_SIZE,
fromfile,
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import Image, Label

Expand Down Expand Up @@ -135,7 +135,7 @@ def _make_datapipe(

dp = Zipper(images_dp, labels_dp)
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))


Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/sbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
Shuffler,
Demultiplexer,
Filter,
IterKeyZipper,
Expand All @@ -29,6 +28,7 @@
path_accessor,
path_comparator,
hint_sharding,
hint_shuffling,
)


Expand Down Expand Up @@ -141,7 +141,7 @@ def _make_datapipe(
split_dp = extra_split_dp
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp)
split_dp = Shuffler(split_dp)
split_dp = hint_shuffling(split_dp)

dp = split_dp
for level, data_dp in enumerate((images_dp, anns_dp)):
Expand Down
5 changes: 2 additions & 3 deletions torchvision/prototype/datasets/_builtin/semeion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
Shuffler,
CSVParser,
)
from torchvision.prototype.datasets.decoder import raw
Expand All @@ -17,7 +16,7 @@
OnlineResource,
DatasetType,
)
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array, hint_sharding
from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling


class SEMEION(Dataset):
Expand Down Expand Up @@ -65,6 +64,6 @@ def _make_datapipe(
dp = resource_dps[0]
dp = CSVParser(dp, delimiter=" ")
dp = hint_sharding(dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return dp
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torchdata.datapipes.iter import (
IterDataPipe,
Mapper,
Shuffler,
Filter,
Demultiplexer,
IterKeyZipper,
Expand All @@ -29,6 +28,7 @@
INFINITE_BUFFER_SIZE,
path_comparator,
hint_sharding,
hint_shuffling,
)

HERE = pathlib.Path(__file__).parent
Expand Down Expand Up @@ -131,7 +131,7 @@ def _make_datapipe(
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp)
split_dp = Shuffler(split_dp, buffer_size=INFINITE_BUFFER_SIZE)
split_dp = hint_shuffling(split_dp)

dp = split_dp
for level, data_dp in enumerate((images_dp, anns_dp)):
Expand Down
6 changes: 5 additions & 1 deletion torchvision/prototype/datasets/utils/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import torch
import torch.distributed as dist
import torch.utils.data
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter
from torchdata.datapipes.iter import IoPathFileLister, IoPathFileLoader, IterDataPipe, ShardingFilter, Shuffler
from torchdata.datapipes.utils import StreamWrapper


Expand Down Expand Up @@ -335,3 +335,7 @@ def read_flo(file: BinaryIO) -> torch.Tensor:

def hint_sharding(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
    """Return ``datapipe`` wrapped in a ``ShardingFilter``."""
    sharded_dp = ShardingFilter(datapipe)
    return sharded_dp


def hint_shuffling(datapipe: IterDataPipe[D]) -> IterDataPipe[D]:
    """Return ``datapipe`` wrapped in a ``Shuffler`` built with ``default=False``.

    The shuffler's buffer size is ``INFINITE_BUFFER_SIZE``.

    NOTE(review): ``default=False`` presumably leaves shuffling switched off
    until it is enabled elsewhere -- confirm against the ``Shuffler`` API.
    """
    shuffler_kwargs = dict(default=False, buffer_size=INFINITE_BUFFER_SIZE)
    return Shuffler(datapipe, **shuffler_kwargs)