
Commit a3572e4

NicolasHug authored and facebook-github-bot committed
[fbsync] call dataset wrapper with idx and sample (#7235)
Reviewed By: vmoens

Differential Revision: D44416250

fbshipit-source-id: 713d63a132a5c57a175465384d461b0b7ff61983
1 parent e3b1d1e commit a3572e4

1 file changed: +20 −14 lines

torchvision/prototype/datapoints/_dataset_wrapper.py

@@ -74,7 +74,7 @@ def __getitem__(self, idx):
         # of this class
         sample = self._dataset[idx]

-        sample = self._wrapper(sample)
+        sample = self._wrapper(idx, sample)

         # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
         # or joint (`transforms`), we can access the full functionality through `transforms`
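
This hunk is the core of the change: the wrapped dataset's `__getitem__` now passes the sample index through to the wrapper, so wrappers can look up per-index metadata on the dataset. A minimal sketch of the calling convention, using a stripped-down stand-in for the real wrapper class (`TinyWrapper` and the toy dataset below are illustrative, not torchvision internals):

class TinyWrapper:
    # Illustrative stand-in for the real dataset wrapper class.
    def __init__(self, dataset, wrapper):
        self._dataset = dataset
        self._wrapper = wrapper

    def __getitem__(self, idx):
        sample = self._dataset[idx]
        # New convention: the wrapper receives the index as well as the
        # raw sample, not just the sample.
        return self._wrapper(idx, sample)

samples = [("img0", 0), ("img1", 1)]
wrapped = TinyWrapper(samples, lambda idx, sample: (idx, *sample))
assert wrapped[1] == (1, "img1", 1)
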
@@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):


 def classification_wrapper_factory(dataset):
-    return identity
+    def wrapper(idx, sample):
+        return sample
+
+    return wrapper


 for dataset_cls in [
@@ -143,7 +146,7 @@ def classification_wrapper_factory(dataset):


 def segmentation_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)

@@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset):
             f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead."
         )

-    def wrapper(sample):
+    def wrapper(idx, sample):
         video, audio, label = sample

         video = datapoints.Video(video)
@@ -201,14 +204,17 @@ def segmentation_to_mask(segmentation, *, spatial_size):
         )
         return torch.from_numpy(mask.decode(segmentation))

-    def wrapper(sample):
+    def wrapper(idx, sample):
+        image_id = dataset.ids[idx]
+
         image, target = sample

+        if not target:
+            return image, dict(image_id=image_id)
+
         batched_target = list_of_dicts_to_dict_of_lists(target)

-        image_ids = batched_target.pop("image_id")
-        image_id = batched_target["image_id"] = image_ids.pop()
-        assert all(other_image_id == image_id for other_image_id in image_ids)
+        batched_target["image_id"] = image_id

         spatial_size = tuple(F.get_spatial_size(image))
         batched_target["boxes"] = datapoints.BoundingBox(
@@ -259,7 +265,7 @@ def wrapper(sample):

 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
 def voc_detection_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
@@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")

-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         target = wrap_target_by_type(
@@ -318,7 +324,7 @@ def wrapper(sample):

 @WRAPPER_FACTORIES.register(datasets.Kitti)
 def kitti_wrapper_factory(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         if target is not None:
@@ -336,7 +342,7 @@ def wrapper(sample):

 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
 def oxford_iiit_pet_wrapper_factor(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         if target is not None:
@@ -371,7 +377,7 @@ def instance_segmentation_wrapper(mask):
             labels.append(label)
         return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))

-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         target = wrap_target_by_type(
@@ -390,7 +396,7 @@ def wrapper(sample):

 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
 def widerface_wrapper(dataset):
-    def wrapper(sample):
+    def wrapper(idx, sample):
         image, target = sample

         if target is not None:
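
The remaining hunks are mechanical: every registered wrapper factory gains the `idx` parameter, whether or not it uses it. Anyone maintaining a custom factory registered through `WRAPPER_FACTORIES` would need the same one-line change; a hypothetical example (`MyDataset` is a placeholder, not a real torchvision dataset):

# @WRAPPER_FACTORIES.register(MyDataset)  # hypothetical registration
def my_dataset_wrapper_factory(dataset):
    def wrapper(idx, sample):  # was: def wrapper(sample)
        image, target = sample
        return image, target  # idx is in scope here even if unused

    return wrapper
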
