From 853613405068eeb3dcd00c7002179fcbe0632ae8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 20:17:07 +0100 Subject: [PATCH 1/3] call dataset wrapper with idx and sample --- .../prototype/datapoints/_dataset_wrapper.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index e60d61e5f90..1f21c7691aa 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -74,7 +74,7 @@ def __getitem__(self, idx): # of this class sample = self._dataset[idx] - sample = self._wrapper(sample) + sample = self._wrapper(idx, sample) # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`) # or joint (`transforms`), we can access the full functionality through `transforms` @@ -125,7 +125,10 @@ def wrap_target_by_type(target, *, target_types, type_wrappers): def classification_wrapper_factory(dataset): - return identity + def wrapper(idx, sample): + return sample + + return wrapper for dataset_cls in [ @@ -143,7 +146,7 @@ def classification_wrapper_factory(dataset): def segmentation_wrapper_factory(dataset): - def wrapper(sample): + def wrapper(idx, sample): image, mask = sample return image, pil_image_to_mask(mask) @@ -163,7 +166,7 @@ def video_classification_wrapper_factory(dataset): f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead." ) - def wrapper(sample): + def wrapper(idx, sample): video, audio, label = sample video = datapoints.Video(video) @@ -201,9 +204,12 @@ def segmentation_to_mask(segmentation, *, spatial_size): ) return torch.from_numpy(mask.decode(segmentation)) - def wrapper(sample): + def wrapper(idx, sample): image, target = sample + if not target: + return image, dict(image_id=dataset.ids[idx]) + batched_target = list_of_dicts_to_dict_of_lists(target) image_ids = batched_target.pop("image_id") @@ -259,7 +265,7 @@ def wrapper(sample): @WRAPPER_FACTORIES.register(datasets.VOCDetection) def voc_detection_wrapper_factory(dataset): - def wrapper(sample): + def wrapper(idx, sample): image, target = sample batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"]) @@ -294,7 +300,7 @@ def celeba_wrapper_factory(dataset): if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]): raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`") - def wrapper(sample): + def wrapper(idx, sample): image, target = sample target = wrap_target_by_type( @@ -318,7 +324,7 @@ def wrapper(sample): @WRAPPER_FACTORIES.register(datasets.Kitti) def kitti_wrapper_factory(dataset): - def wrapper(sample): + def wrapper(idx, sample): image, target = sample if target is not None: @@ -336,7 +342,7 @@ def wrapper(sample): @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet) def oxford_iiit_pet_wrapper_factor(dataset): - def wrapper(sample): + def wrapper(idx, sample): image, target = sample if target is not None: @@ -371,7 +377,7 @@ def instance_segmentation_wrapper(mask): labels.append(label) return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels)) - def wrapper(sample): + def wrapper(idx, sample): image, target = sample target = wrap_target_by_type( @@ -390,7 +396,7 @@ def wrapper(sample): @WRAPPER_FACTORIES.register(datasets.WIDERFace) def widerface_wrapper(dataset): - def wrapper(sample): + def wrapper(idx, sample): image, target = sample if target is not None: From b9974cbbde1525c7aee3b14512fa04a7690a5fd4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 20:42:55 +0100 Subject: [PATCH 2/3] refactor image_id retrieval --- torchvision/prototype/datapoints/_dataset_wrapper.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 1f21c7691aa..922bdd64926 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -205,16 +205,17 @@ def segmentation_to_mask(segmentation, *, spatial_size): return torch.from_numpy(mask.decode(segmentation)) def wrapper(idx, sample): + image_id = dataset.ids[idx] + image, target = sample if not target: - return image, dict(image_id=dataset.ids[idx]) + return image, dict(image_id=image_id) batched_target = list_of_dicts_to_dict_of_lists(target) - image_ids = batched_target.pop("image_id") - image_id = batched_target["image_id"] = image_ids.pop() - assert all(other_image_id == image_id for other_image_id in image_ids) + assert all(image_id_from_target == image_id for image_id_from_target in batched_target.pop("image_id")) + batched_target["image_id"] = image_id spatial_size = tuple(F.get_spatial_size(image)) batched_target["boxes"] = datapoints.BoundingBox( From 8b32391bc0b153c556d35a0c7a810e8b254380f3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Feb 2023 20:47:07 +0100 Subject: [PATCH 3/3] remove image id check --- torchvision/prototype/datapoints/_dataset_wrapper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 922bdd64926..dc4578c49f4 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -214,7 +214,6 @@ def wrapper(idx, sample): batched_target = list_of_dicts_to_dict_of_lists(target) - assert all(image_id_from_target == image_id for image_id_from_target in batched_target.pop("image_id")) batched_target["image_id"] = image_id spatial_size = tuple(F.get_spatial_size(image))