diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 598d4408b76..c02ffeb0d68 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -596,7 +596,7 @@ def test_transforms_v2_wrapper(self, config): wrapped_sample = wrapped_dataset[0] assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample) except TypeError as error: - if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"): + if str(error).startswith(f"No wrapper exists for dataset class {type(dataset).__name__}"): return raise error except RuntimeError as error: diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index 4663cdac3da..c2cc0986b71 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -1,7 +1,11 @@ +import re + import pytest import torch from PIL import Image + +from torchvision import datasets from torchvision.prototype import datapoints @@ -159,3 +163,43 @@ def test_bbox_instance(data, format): if isinstance(format, str): format = datapoints.BoundingBoxFormat.from_str(format.upper()) assert bboxes.format == format + + +class TestDatasetWrapper: + def test_unknown_type(self): + unknown_object = object() + with pytest.raises( + TypeError, match=re.escape("is meant for subclasses of `torchvision.datasets.VisionDataset`") + ): + datapoints.wrap_dataset_for_transforms_v2(unknown_object) + + def test_unknown_dataset(self): + class MyVisionDataset(datasets.VisionDataset): + pass + + dataset = MyVisionDataset("root") + + with pytest.raises(TypeError, match="No wrapper exist"): + datapoints.wrap_dataset_for_transforms_v2(dataset) + + def test_missing_wrapper(self): + dataset = datasets.FakeData() + + with pytest.raises(TypeError, match="please open an issue"): + datapoints.wrap_dataset_for_transforms_v2(dataset) + + def test_subclass(self, mocker): + sentinel = object() + mocker.patch.dict( + datapoints._dataset_wrapper.WRAPPER_FACTORIES, + clear=False, + values={datasets.FakeData: lambda dataset: lambda idx, sample: sentinel}, + ) + + class MyFakeData(datasets.FakeData): + pass + + dataset = MyFakeData() + wrapped_dataset = datapoints.wrap_dataset_for_transforms_v2(dataset) + + assert wrapped_dataset[0] is sentinel diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index dc4578c49f4..74f83095177 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -39,16 +39,26 @@ def decorator(wrapper_factory): class VisionDatasetDatapointWrapper(Dataset): def __init__(self, dataset): dataset_cls = type(dataset) - wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls) - if wrapper_factory is None: - # TODO: If we have documentation on how to do that, put a link in the error message. - msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." - if dataset_cls in datasets.__dict__.values(): - msg = ( - f"{msg} If an automated wrapper for this dataset would be useful for you, " - f"please open an issue at https://github.com/pytorch/vision/issues." - ) - raise TypeError(msg) + + if not isinstance(dataset, datasets.VisionDataset): + raise TypeError( + f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, " + f"but got a '{dataset_cls.__name__}' instead." + ) + + for cls in dataset_cls.mro(): + if cls in WRAPPER_FACTORIES: + wrapper_factory = WRAPPER_FACTORIES[cls] + break + elif cls is datasets.VisionDataset: + # TODO: If we have documentation on how to do that, put a link in the error message. + msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself." + if dataset_cls in datasets.__dict__.values(): + msg = ( + f"{msg} If an automated wrapper for this dataset would be useful for you, " + f"please open an issue at https://github.com/pytorch/vision/issues." + ) + raise TypeError(msg) self._dataset = dataset self._wrapper = wrapper_factory(dataset) @@ -98,6 +108,13 @@ def identity(item): return item +def identity_wrapper_factory(dataset): + def wrapper(idx, sample): + return sample + + return wrapper + + def pil_image_to_mask(pil_image): return datapoints.Mask(pil_image) @@ -125,10 +142,7 @@ def wrap_target_by_type(target, *, target_types, type_wrappers): def classification_wrapper_factory(dataset): - def wrapper(idx, sample): - return sample - - return wrapper + return identity_wrapper_factory(dataset) for dataset_cls in [ @@ -237,6 +251,9 @@ def wrapper(idx, sample): return wrapper +WRAPPER_FACTORIES.register(datasets.CocoCaptions)(identity_wrapper_factory) + + VOC_DETECTION_CATEGORIES = [ "__background__", "aeroplane",