diff --git a/gallery/plot_transforms_v2_e2e.py b/gallery/plot_transforms_v2_e2e.py
index aa25d214f31..5d8d22dce83 100644
--- a/gallery/plot_transforms_v2_e2e.py
+++ b/gallery/plot_transforms_v2_e2e.py
@@ -75,7 +75,8 @@ def load_example_coco_detection_dataset(**kwargs):
 # :func:`~torchvision.datasets.wrap_dataset_for_transforms_v2` function. For
 # :class:`~torchvision.datasets.CocoDetection`, this changes the target structure to a single dictionary of lists. It
 # also adds the key-value-pairs ``"boxes"``, ``"masks"``, and ``"labels"`` wrapped in the corresponding
-# ``torchvision.datapoints``.
+# ``torchvision.datapoints``. By default, it only returns ``"boxes"`` and ``"labels"`` to avoid transforming unnecessary
+# items down the line, but you can pass the ``target_keys`` parameter for fine-grained control.
 
 dataset = datasets.wrap_dataset_for_transforms_v2(dataset)
 
@@ -83,7 +84,7 @@ def load_example_coco_detection_dataset(**kwargs):
 image, target = sample
 print(type(image))
 print(type(target), list(target.keys()))
-print(type(target["boxes"]), type(target["masks"]), type(target["labels"]))
+print(type(target["boxes"]), type(target["labels"]))
 
 ########################################################################################################################
 # As baseline, let's have a look at a sample without transformations:
diff --git a/test/datasets_utils.py b/test/datasets_utils.py
index 768324955b4..169437a7424 100644
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ -572,9 +572,21 @@ def test_transforms_v2_wrapper(self, config):
 
         try:
             with self.create_dataset(config) as (dataset, _):
-                wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
-                wrapped_sample = wrapped_dataset[0]
-                assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
+                for target_keys in [None, "all"]:
+                    if target_keys is not None and self.DATASET_CLASS not in {
+                        torchvision.datasets.CocoDetection,
+                        torchvision.datasets.VOCDetection,
+                        torchvision.datasets.Kitti,
+                        torchvision.datasets.WIDERFace,
+                    }:
+                        with self.assertRaisesRegex(ValueError, "`target_keys` is currently only supported for"):
+                            wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
+                        continue
+
+                    wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
+                    wrapped_sample = wrapped_dataset[0]
+
+                    assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
         except TypeError as error:
             msg = f"No wrapper exists for dataset class {type(dataset).__name__}"
             if str(error).startswith(msg):
diff --git a/test/test_datasets.py b/test/test_datasets.py
index 48d08b846de..ed6aa17d3f9 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -771,6 +771,8 @@ def _create_annotations(self, image_ids, num_annotations_per_image):
                     bbox=torch.rand(4).tolist(),
                     segmentation=[torch.rand(8).tolist()],
                     category_id=int(torch.randint(91, ())),
+                    area=float(torch.rand(1)),
+                    iscrowd=int(torch.randint(2, size=(1,))),
                 )
             )
             annotion_id += 1
@@ -3336,7 +3338,7 @@ def test_subclass(self, mocker):
         mocker.patch.dict(
             datapoints._dataset_wrapper.WRAPPER_FACTORIES,
             clear=False,
-            values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel},
+            values={datasets.FakeData: lambda dataset, target_keys: lambda idx, sample: sentinel},
         )
 
         class MyFakeData(datasets.FakeData):
diff --git a/torchvision/datapoints/_dataset_wrapper.py b/torchvision/datapoints/_dataset_wrapper.py
index 87ce3ba93a1..cce8f1b2e84 100644
--- a/torchvision/datapoints/_dataset_wrapper.py
+++ b/torchvision/datapoints/_dataset_wrapper.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import collections.abc
+
 import contextlib
 from collections import defaultdict
 
@@ -14,7 +16,7 @@
 __all__ = ["wrap_dataset_for_transforms_v2"]
 
 
-def wrap_dataset_for_transforms_v2(dataset):
+def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
     """[BETA] Wrap a ``torchvision.dataset`` for usage with :mod:`torchvision.transforms.v2`.
 
     .. v2betastatus:: wrap_dataset_for_transforms_v2 function
@@ -36,15 +38,17 @@ def wrap_dataset_for_transforms_v2(dataset):
     * :class:`~torchvision.datasets.CocoDetection`: Instead of returning the target as list of dicts, the wrapper
       returns a dict of lists. In addition, the key-value-pairs ``"boxes"`` (in ``XYXY`` coordinate format),
       ``"masks"`` and ``"labels"`` are added and wrap the data in the corresponding ``torchvision.datapoints``.
-      The original keys are preserved.
+      The original keys are preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"``
+      and ``"labels"``.
     * :class:`~torchvision.datasets.VOCDetection`: The key-value-pairs ``"boxes"`` and ``"labels"`` are added to
       the target and wrap the data in the corresponding ``torchvision.datapoints``. The original keys are
-      preserved.
+      preserved. If ``target_keys`` is omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
     * :class:`~torchvision.datasets.CelebA`: The target for ``target_type="bbox"`` is converted to the ``XYXY``
      coordinate format and wrapped into a :class:`~torchvision.datapoints.BoundingBox` datapoint.
-    * :class:`~torchvision.datasets.Kitti`: Instead returning the target as list of dictsthe wrapper returns a dict
-      of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
-      in the corresponding ``torchvision.datapoints``. The original keys are preserved.
+    * :class:`~torchvision.datasets.Kitti`: Instead of returning the target as list of dicts, the wrapper returns a
+      dict of lists. In addition, the key-value-pairs ``"boxes"`` and ``"labels"`` are added and wrap the data
+      in the corresponding ``torchvision.datapoints``. The original keys are preserved. If ``target_keys`` is
+      omitted, returns only the values for the ``"boxes"`` and ``"labels"``.
     * :class:`~torchvision.datasets.OxfordIIITPet`: The target for ``target_type="segmentation"`` is wrapped into a
      :class:`~torchvision.datapoints.Mask` datapoint.
    * :class:`~torchvision.datasets.Cityscapes`: The target for ``target_type="semantic"`` is wrapped into a
@@ -61,13 +65,13 @@ def wrap_dataset_for_transforms_v2(dataset):
 
     Segmentation datasets
 
-        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation` return a two-tuple of
+        Segmentation datasets, e.g. :class:`~torchvision.datasets.VOCSegmentation`, return a two-tuple of
         :class:`PIL.Image.Image`'s. This wrapper leaves the image as is (first item), while wrapping the
         segmentation mask into a :class:`~torchvision.datapoints.Mask` (second item).
 
     Video classification datasets
 
-        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics` return a three-tuple containing a
+        Video classification datasets, e.g. :class:`~torchvision.datasets.Kinetics`, return a three-tuple containing a
         :class:`torch.Tensor` for the video and audio and a :class:`int` as label. This wrapper wraps the video into a
         :class:`~torchvision.datapoints.Video` while leaving the other items as is.
 
@@ -78,8 +82,23 @@ def wrap_dataset_for_transforms_v2(dataset):
 
     Args:
         dataset: the dataset instance to wrap for compatibility with transforms v2.
+        target_keys: Target keys to return in case the target is a dictionary. If ``None`` (default), selected keys are
+            specific to the dataset. If ``"all"``, returns the full target. Can also be a collection of strings for
+            fine-grained access. Currently only supported for :class:`~torchvision.datasets.CocoDetection`,
+            :class:`~torchvision.datasets.VOCDetection`, :class:`~torchvision.datasets.Kitti`, and
+            :class:`~torchvision.datasets.WIDERFace`. See above for details.
     """
-    return VisionDatasetDatapointWrapper(dataset)
+    if not (
+        target_keys is None
+        or target_keys == "all"
+        or (isinstance(target_keys, collections.abc.Collection) and all(isinstance(key, str) for key in target_keys))
+    ):
+        raise ValueError(
+            f"`target_keys` can be None, 'all', or a collection of strings denoting the keys to be returned, "
+            f"but got {target_keys}"
+        )
+
+    return VisionDatasetDatapointWrapper(dataset, target_keys)
 
 
 class WrapperFactories(dict):
@@ -99,7 +118,7 @@ def decorator(wrapper_factory):
 
 
 class VisionDatasetDatapointWrapper(Dataset):
-    def __init__(self, dataset):
+    def __init__(self, dataset, target_keys):
         dataset_cls = type(dataset)
 
         if not isinstance(dataset, datasets.VisionDataset):
@@ -111,6 +130,16 @@ def __init__(self, dataset):
         for cls in dataset_cls.mro():
             if cls in WRAPPER_FACTORIES:
                 wrapper_factory = WRAPPER_FACTORIES[cls]
+                if target_keys is not None and cls not in {
+                    datasets.CocoDetection,
+                    datasets.VOCDetection,
+                    datasets.Kitti,
+                    datasets.WIDERFace,
+                }:
+                    raise ValueError(
+                        f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
+                        f"and `WIDERFace`, but got {cls.__name__}."
+                    )
                 break
             elif cls is datasets.VisionDataset:
                 # TODO: If we have documentation on how to do that, put a link in the error message.
@@ -123,7 +152,7 @@ def __init__(self, dataset):
             raise TypeError(msg)
 
         self._dataset = dataset
-        self._wrapper = wrapper_factory(dataset)
+        self._wrapper = wrapper_factory(dataset, target_keys)
 
         # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
         # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
@@ -170,7 +199,7 @@ def identity(item):
     return item
 
 
-def identity_wrapper_factory(dataset):
+def identity_wrapper_factory(dataset, target_keys):
     def wrapper(idx, sample):
         return sample
 
@@ -181,6 +210,20 @@ def pil_image_to_mask(pil_image):
     return datapoints.Mask(pil_image)
 
 
+def parse_target_keys(target_keys, *, available, default):
+    if target_keys is None:
+        target_keys = default
+    if target_keys == "all":
+        target_keys = available
+    else:
+        target_keys = set(target_keys)
+        extra = target_keys - available
+        if extra:
+            raise ValueError(f"Target keys {sorted(extra)} are not available")
+
+    return target_keys
+
+
 def list_of_dicts_to_dict_of_lists(list_of_dicts):
     dict_of_lists = defaultdict(list)
     for dct in list_of_dicts:
@@ -203,8 +246,8 @@ def wrap_target_by_type(target, *, target_types, type_wrappers):
     return wrapped_target
 
 
-def classification_wrapper_factory(dataset):
-    return identity_wrapper_factory(dataset)
+def classification_wrapper_factory(dataset, target_keys):
+    return identity_wrapper_factory(dataset, target_keys)
 
 
 for dataset_cls in [
@@ -221,7 +264,7 @@ def classification_wrapper_factory(dataset):
     WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)
 
 
-def segmentation_wrapper_factory(dataset):
+def segmentation_wrapper_factory(dataset, target_keys):
    def wrapper(idx, sample):
         image, mask = sample
         return image, pil_image_to_mask(mask)
@@ -235,7 +278,7 @@ def wrapper(idx, sample):
     WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory)
 
 
-def video_classification_wrapper_factory(dataset):
+def video_classification_wrapper_factory(dataset, target_keys):
     if dataset.video_clips.output_format == "THWC":
         raise RuntimeError(
             f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, "
@@ -261,15 +304,33 @@ def wrapper(idx, sample):
 
 
 @WRAPPER_FACTORIES.register(datasets.Caltech101)
-def caltech101_wrapper_factory(dataset):
+def caltech101_wrapper_factory(dataset, target_keys):
     if "annotation" in dataset.target_type:
         raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")
 
-    return classification_wrapper_factory(dataset)
+    return classification_wrapper_factory(dataset, target_keys)
 
 
 @WRAPPER_FACTORIES.register(datasets.CocoDetection)
-def coco_dectection_wrapper_factory(dataset):
+def coco_dectection_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "segmentation",
+            "area",
+            "iscrowd",
+            "image_id",
+            "bbox",
+            "category_id",
+            # added by the wrapper
+            "boxes",
+            "masks",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def segmentation_to_mask(segmentation, *, spatial_size):
         from pycocotools import mask
 
@@ -288,30 +349,41 @@ def wrapper(idx, sample):
         if not target:
             return image, dict(image_id=image_id)
 
+        spatial_size = tuple(F.get_spatial_size(image))
+
         batched_target = list_of_dicts_to_dict_of_lists(target)
+        target = {}
 
-        batched_target["image_id"] = image_id
+        if "image_id" in target_keys:
+            target["image_id"] = image_id
 
-        spatial_size = tuple(F.get_spatial_size(image))
-        batched_target["boxes"] = F.convert_format_bounding_box(
-            datapoints.BoundingBox(
-                batched_target["bbox"],
-                format=datapoints.BoundingBoxFormat.XYWH,
-                spatial_size=spatial_size,
-            ),
-            new_format=datapoints.BoundingBoxFormat.XYXY,
-        )
-        batched_target["masks"] = datapoints.Mask(
-            torch.stack(
-                [
-                    segmentation_to_mask(segmentation, spatial_size=spatial_size)
-                    for segmentation in batched_target["segmentation"]
-                ]
-            ),
-        )
-        batched_target["labels"] = torch.tensor(batched_target["category_id"])
+        if "boxes" in target_keys:
+            target["boxes"] = F.convert_format_bounding_box(
+                datapoints.BoundingBox(
+                    batched_target["bbox"],
+                    format=datapoints.BoundingBoxFormat.XYWH,
+                    spatial_size=spatial_size,
+                ),
+                new_format=datapoints.BoundingBoxFormat.XYXY,
+            )
+
+        if "masks" in target_keys:
+            target["masks"] = datapoints.Mask(
+                torch.stack(
+                    [
+                        segmentation_to_mask(segmentation, spatial_size=spatial_size)
+                        for segmentation in batched_target["segmentation"]
+                    ]
+                ),
+            )
+
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor(batched_target["category_id"])
 
-        return image, batched_target
+        for target_key in target_keys - {"image_id", "boxes", "masks", "labels"}:
+            target[target_key] = batched_target[target_key]
+
+        return image, target
 
     return wrapper
 
@@ -346,23 +418,41 @@ def wrapper(idx, sample):
 
 
 @WRAPPER_FACTORIES.register(datasets.VOCDetection)
-def voc_detection_wrapper_factory(dataset):
+def voc_detection_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "annotation",
+            # added by the wrapper
+            "boxes",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def wrapper(idx, sample):
         image, target = sample
 
         batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"])
 
-        target["boxes"] = datapoints.BoundingBox(
-            [
-                [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
-                for bndbox in batched_instances["bndbox"]
-            ],
-            format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=(image.height, image.width),
-        )
-        target["labels"] = torch.tensor(
-            [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
-        )
+        if "annotation" not in target_keys:
+            target = {}
+
+        if "boxes" in target_keys:
+            target["boxes"] = datapoints.BoundingBox(
+                [
+                    [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")]
+                    for bndbox in batched_instances["bndbox"]
+                ],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=(image.height, image.width),
+            )
+
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor(
+                [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]]
+            )
 
         return image, target
 
@@ -370,15 +460,15 @@ def wrapper(idx, sample):
 
 
 @WRAPPER_FACTORIES.register(datasets.SBDataset)
-def sbd_wrapper(dataset):
+def sbd_wrapper(dataset, target_keys):
     if dataset.mode == "boundaries":
         raise_not_supported("SBDataset with mode='boundaries'")
 
-    return segmentation_wrapper_factory(dataset)
+    return segmentation_wrapper_factory(dataset, target_keys)
 
 
 @WRAPPER_FACTORIES.register(datasets.CelebA)
-def celeba_wrapper_factory(dataset):
+def celeba_wrapper_factory(dataset, target_keys):
     if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]):
         raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`")
 
@@ -410,17 +500,47 @@ def wrapper(idx, sample):
 
 
 @WRAPPER_FACTORIES.register(datasets.Kitti)
-def kitti_wrapper_factory(dataset):
+def kitti_wrapper_factory(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            # native
+            "type",
+            "truncated",
+            "occluded",
+            "alpha",
+            "bbox",
+            "dimensions",
+            "location",
+            "rotation_y",
+            # added by the wrapper
+            "boxes",
+            "labels",
+        },
+        default={"boxes", "labels"},
+    )
+
     def wrapper(idx, sample):
         image, target = sample
 
-        if target is not None:
-            target = list_of_dicts_to_dict_of_lists(target)
+        if target is None:
+            return image, target
+
+        batched_target = list_of_dicts_to_dict_of_lists(target)
+        target = {}
 
+        if "boxes" in target_keys:
             target["boxes"] = datapoints.BoundingBox(
-                target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width)
+                batched_target["bbox"],
+                format=datapoints.BoundingBoxFormat.XYXY,
+                spatial_size=(image.height, image.width),
             )
-            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]])
+
+        if "labels" in target_keys:
+            target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in batched_target["type"]])
+
+        for target_key in target_keys - {"boxes", "labels"}:
+            target[target_key] = batched_target[target_key]
 
         return image, target
 
@@ -428,7 +548,7 @@ def wrapper(idx, sample):
 
 
 @WRAPPER_FACTORIES.register(datasets.OxfordIIITPet)
-def oxford_iiit_pet_wrapper_factor(dataset):
+def oxford_iiit_pet_wrapper_factor(dataset, target_keys):
     def wrapper(idx, sample):
         image, target = sample
 
@@ -447,7 +567,7 @@ def wrapper(idx, sample):
 
 
 @WRAPPER_FACTORIES.register(datasets.Cityscapes)
-def cityscapes_wrapper_factory(dataset):
+def cityscapes_wrapper_factory(dataset, target_keys):
     if any(target_type in dataset.target_type for target_type in ["polygon", "color"]):
         raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`")
 
@@ -482,11 +602,30 @@ def wrapper(idx, sample):
 
 
 @WRAPPER_FACTORIES.register(datasets.WIDERFace)
-def widerface_wrapper(dataset):
+def widerface_wrapper(dataset, target_keys):
+    target_keys = parse_target_keys(
+        target_keys,
+        available={
+            "bbox",
+            "blur",
+            "expression",
+            "illumination",
+            "occlusion",
+            "pose",
+            "invalid",
+        },
+        default="all",
+    )
+
    def wrapper(idx, sample):
         image, target = sample
 
-        if target is not None:
+        if target is None:
+            return image, target
+
+        target = {key: target[key] for key in target_keys}
+
+        if "bbox" in target_keys:
             target["bbox"] = F.convert_format_bounding_box(
                 datapoints.BoundingBox(
                     target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
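
Usage sketch (not part of the patch): the call pattern below is inferred from the docstring and gallery changes above. The COCO paths are placeholders.

    import torchvision.datasets as datasets

    # Hypothetical paths -- point these at a real COCO-style image folder and annotation file.
    dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")

    # Default: the wrapped target only contains "boxes" and "labels".
    wrapped = datasets.wrap_dataset_for_transforms_v2(dataset)

    # Opt back into the full target, including the wrapped "masks".
    wrapped_all = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")

    # Fine-grained selection with a collection of strings.
    wrapped_subset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})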