
Commit 1efb567

remove fn_kwargs from Filter and Mapper datapipes (#5113)
* remove fn_kwargs from Filter and Mapper datapipes
* fix leftovers
1 parent 40be657 commit 1efb567
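
Every change below follows the same pattern: instead of forwarding extra keyword arguments to a Mapper or Filter datapipe via fn_kwargs, the arguments are bound to the callable up front with functools.partial, so the datapipe only ever receives a single-argument callable. A minimal, self-contained sketch of the pattern (collate_and_decode_sample here is an illustrative stand-in, not the actual torchvision helper):

    import functools

    def collate_and_decode_sample(data, *, decoder):
        # Stand-in for the collate/decode helpers touched in this commit.
        return decoder(data)

    # Old pattern (removed by this commit):
    #     Mapper(dp, collate_and_decode_sample, fn_kwargs=dict(decoder=my_decoder))
    # New pattern:
    #     Mapper(dp, functools.partial(collate_and_decode_sample, decoder=my_decoder))

    # The bound callable behaves like the original function with the keyword pre-filled:
    fn = functools.partial(collate_and_decode_sample, decoder=str.upper)
    print(fn("abc"))  # prints: ABC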

File tree

10 files changed: 30 additions (+), 19 deletions (-)


torchvision/prototype/datasets/_builtin/caltech.py

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+import functools
 import io
 import pathlib
 import re
@@ -132,7 +133,7 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
             keep_key=True,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
@@ -185,7 +186,7 @@ def _make_datapipe(
         dp = Filter(dp, self._is_not_rogue_file)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,5 @@
 import csv
+import functools
 import io
 from typing import Any, Callable, Dict, List, Optional, Tuple, Iterator, Sequence

@@ -26,7 +27,6 @@
     hint_shuffling,
 )

-
 csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)


@@ -155,7 +155,7 @@ def _make_datapipe(
         splits_dp, images_dp, identities_dp, attributes_dp, bboxes_dp, landmarks_dp = resource_dps

         splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id"))
-        splits_dp = Filter(splits_dp, self._filter_split, fn_kwargs=dict(split=config.split))
+        splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split))
         splits_dp = hint_sharding(splits_dp)
         splits_dp = hint_shuffling(splits_dp)

@@ -181,4 +181,4 @@ def _make_datapipe(
             keep_key=True,
         )
         dp = IterKeyZipper(dp, anns_dp, key_fn=getitem(0), buffer_size=INFINITE_BUFFER_SIZE)
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ def _make_datapipe(
         dp = CifarFileReader(dp, labels_key=self._LABELS_KEY)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)

torchvision/prototype/datasets/_builtin/coco.py

Lines changed: 11 additions & 5 deletions
@@ -1,3 +1,4 @@
+import functools
 import io
 import pathlib
 import re
@@ -183,12 +184,16 @@ def _make_datapipe(
         if config.annotations is None:
             dp = hint_sharding(images_dp)
             dp = hint_shuffling(dp)
-            return Mapper(dp, self._collate_and_decode_image, fn_kwargs=dict(decoder=decoder))
+            return Mapper(dp, functools.partial(self._collate_and_decode_image, decoder=decoder))

         meta_dp = Filter(
             meta_dp,
-            self._filter_meta_files,
-            fn_kwargs=dict(split=config.split, year=config.year, annotations=config.annotations),
+            functools.partial(
+                self._filter_meta_files,
+                split=config.split,
+                year=config.year,
+                annotations=config.annotations,
+            ),
         )
         meta_dp = JsonParser(meta_dp)
         meta_dp = Mapper(meta_dp, getitem(1))
@@ -226,7 +231,7 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
         )
         return Mapper(
-            dp, self._collate_and_decode_sample, fn_kwargs=dict(annotations=config.annotations, decoder=decoder)
+            dp, functools.partial(self._collate_and_decode_sample, annotations=config.annotations, decoder=decoder)
         )

     def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:
@@ -235,7 +240,8 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]:

         dp = resources[1].load(pathlib.Path(root) / self.name)
         dp = Filter(
-            dp, self._filter_meta_files, fn_kwargs=dict(split=config.split, year=config.year, annotations="instances")
+            dp,
+            functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"),
         )
         dp = JsonParser(dp)
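The coco.py hunks apply the same refactor to a Filter predicate that takes several keyword arguments. A short sketch of binding such a predicate with functools.partial (filter_meta_files below is a simplified illustration with the same keyword signature as the diff, not the real _filter_meta_files implementation):

    import functools

    def filter_meta_files(data, *, split, year, annotations):
        # Illustrative predicate: keep only the annotation file matching the requested configuration.
        path = data[0]
        return path == f"{annotations}_{split}{year}.json"

    predicate = functools.partial(filter_meta_files, split="train", year="2017", annotations="instances")

    files = [("instances_train2017.json", None), ("captions_val2017.json", None)]
    print([name for name, _ in filter(predicate, files)])  # prints: ['instances_train2017.json']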

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import functools
 import io
 import pathlib
 import re
@@ -165,7 +166,7 @@ def _make_datapipe(
             dp = hint_shuffling(dp)
             dp = Mapper(dp, self._collate_test_data)

-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))

     # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
     # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def _make_datapipe(
         dp = Zipper(images_dp, labels_dp)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, self._collate_and_decode, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode, config=config, decoder=decoder))


 class MNIST(_MNISTBase):

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import functools
 import io
 import pathlib
 import re
@@ -152,7 +153,7 @@ def _make_datapipe(
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))

     def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)

torchvision/prototype/datasets/_builtin/semeion.py

Lines changed: 2 additions & 1 deletion
@@ -1,3 +1,4 @@
+import functools
 import io
 from typing import Any, Callable, Dict, List, Optional, Tuple

@@ -65,5 +66,5 @@ def _make_datapipe(
         dp = CSVParser(dp, delimiter=" ")
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
+        dp = Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
         return dp

torchvision/prototype/datasets/_builtin/voc.py

Lines changed: 2 additions & 2 deletions
@@ -127,7 +127,7 @@ def _make_datapipe(
             buffer_size=INFINITE_BUFFER_SIZE,
         )

-        split_dp = Filter(split_dp, self._is_in_folder, fn_kwargs=dict(name=self._SPLIT_FOLDER[config.task]))
+        split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task]))
         split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
         split_dp = hint_sharding(split_dp)
@@ -142,4 +142,4 @@ def _make_datapipe(
             ref_key_fn=path_accessor("stem"),
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(config=config, decoder=decoder))
+        return Mapper(dp, functools.partial(self._collate_and_decode_sample, config=config, decoder=decoder))

torchvision/prototype/datasets/_folder.py

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+import functools
 import io
 import os
 import os.path
@@ -50,12 +51,12 @@ def from_data_folder(
     categories = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
     masks: Union[List[str], str] = [f"*.{ext}" for ext in valid_extensions] if valid_extensions is not None else ""
     dp = FileLister(str(root), recursive=recursive, masks=masks)
-    dp: IterDataPipe = Filter(dp, _is_not_top_level_file, fn_kwargs=dict(root=root))
+    dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
     dp = hint_sharding(dp)
     dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
     dp = FileLoader(dp)
     return (
-        Mapper(dp, _collate_and_decode_data, fn_kwargs=dict(root=root, categories=categories, decoder=decoder)),
+        Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
         categories,
     )
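Taken together, the new pattern composes like the sketch below. It assumes the Mapper, Filter, and IterableWrapper datapipes importable from torch.utils.data.datapipes.iter; the exact import path may differ between torch/torchdata versions, and the file names, extension, and root directory are made up for illustration:

    import functools

    # Import path is an assumption; recent torch/torchdata releases also expose these datapipes elsewhere.
    from torch.utils.data.datapipes.iter import Filter, IterableWrapper, Mapper

    def keep_extension(path, *, ext):
        # Illustrative predicate: keep only paths ending in the given extension.
        return path.endswith(ext)

    def collate(path, *, root):
        # Illustrative collate helper: pair each path with its root directory.
        return {"root": root, "path": path}

    dp = IterableWrapper(["a.jpg", "b.txt", "c.jpg"])
    dp = Filter(dp, functools.partial(keep_extension, ext=".jpg"))
    dp = Mapper(dp, functools.partial(collate, root="/data"))
    print(list(dp))  # [{'root': '/data', 'path': 'a.jpg'}, {'root': '/data', 'path': 'c.jpg'}]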

0 commit comments
