Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions torchvision/prototype/datasets/_builtin/caltech.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import pathlib
import re
Expand Down Expand Up @@ -132,7 +133,7 @@ def _make_datapipe(
buffer_size=INFINITE_BUFFER_SIZE,
keep_key=True,
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down Expand Up @@ -185,7 +186,7 @@ def _make_datapipe(
dp = Filter(dp, self._is_not_rogue_file)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/celeba.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import csv
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence

Expand Down Expand Up @@ -26,7 +27,6 @@
hint_shuffling,
)


csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)


Expand Down Expand Up @@ -181,4 +181,4 @@ def _make_datapipe(
keep_key=True,
)
dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _make_datapipe(
dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> List[str]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down
5 changes: 3 additions & 2 deletions torchvision/prototype/datasets/_builtin/coco.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import pathlib
import re
Expand Down Expand Up @@ -183,7 +184,7 @@ def _make_datapipe(
if config.annotations is None:
dp = hint_sharding(images_dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder))

meta_dp = Filter(
meta_dp,
Expand Down Expand Up @@ -226,7 +227,7 @@ def _make_datapipe(
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(
dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder)
dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder)
)

def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/imagenet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import pathlib
import re
Expand Down Expand Up @@ -165,7 +166,7 @@ def _make_datapipe(
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_test_data)

return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/_builtin/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def _make_datapipe(
dp = Zipper(images_dp, labels_dp)
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder))


class MNIST(_MNISTBase):
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/sbd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
import pathlib
import re
Expand Down Expand Up @@ -152,7 +153,7 @@ def _make_datapipe(
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))

def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/datasets/_builtin/semeion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
from typing import Any, Callable, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -65,5 +66,5 @@ def _make_datapipe(
dp = CSVParser(dp, delimiter=" ")
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
return dp
4 changes: 2 additions & 2 deletions torchvision/prototype/datasets/_builtin/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _make_datapipe(
buffer_size=INFINITE_BUFFER_SIZE,
)

split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task]))
split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task]))
split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp)
Expand All @@ -142,4 +142,4 @@ def _make_datapipe(
ref_key_fn=path_accessor("stem"),
buffer_size=INFINITE_BUFFER_SIZE,
)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))