diff --git a/test/test_datasets.py b/test/test_datasets.py index bee781d488d..26064a11c71 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1620,6 +1620,10 @@ def inject_fake_data(self, tmpdir, config): num_examples_total += num_examples classes.append(cls) + if config.pop("make_empty_class", False): + os.makedirs(pathlib.Path(tmpdir) / "empty_class") + classes.append("empty_class") + return dict(num_examples=num_examples_total, classes=classes) def _file_name_fn(self, cls, ext, idx): @@ -1644,6 +1648,23 @@ def test_classes(self, config): assert len(dataset.classes) == len(info["classes"]) assert all([a == b for a, b in zip(dataset.classes, info["classes"])]) + def test_allow_empty(self): + config = { + "extensions": self._EXTENSIONS, + "make_empty_class": True, + } + + config["allow_empty"] = True + with self.create_dataset(config) as (dataset, info): + assert "empty_class" in dataset.classes + assert len(dataset.classes) == len(info["classes"]) + assert all([a == b for a, b in zip(dataset.classes, info["classes"])]) + + config["allow_empty"] = False + with pytest.raises(FileNotFoundError, match="Found no valid file"): + with self.create_dataset(config) as (dataset, info): + pass + class ImageFolderTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.ImageFolder diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 49519233810..3a38d556287 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -50,6 +50,7 @@ def make_dataset( class_to_idx: Optional[Dict[str, int]] = None, extensions: Optional[Union[str, Tuple[str, ...]]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, ) -> List[Tuple[str, int]]: """Generates a list of samples of a form (path_to_sample, class). @@ -95,7 +96,7 @@ def is_valid_file(x: str) -> bool: available_classes.add(target_class) empty_classes = set(class_to_idx.keys()) - available_classes - if empty_classes: + if empty_classes and not allow_empty: msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " if extensions is not None: msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" @@ -123,6 +124,8 @@ class DatasetFolder(VisionDataset): is_valid_file (callable, optional): A function that takes path of a file and check if the file is a valid file (used to check of corrupt files) both extensions and is_valid_file should not be passed. + allow_empty(bool, optional): If True, empty folders are considered to be valid classes. + An error is raised on empty folders if False (default). Attributes: classes (list): List of the class names sorted alphabetically. @@ -139,10 +142,17 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) classes, class_to_idx = self.find_classes(self.root) - samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) + samples = self.make_dataset( + self.root, + class_to_idx=class_to_idx, + extensions=extensions, + is_valid_file=is_valid_file, + allow_empty=allow_empty, + ) self.loader = loader self.extensions = extensions @@ -158,6 +168,7 @@ def make_dataset( class_to_idx: Dict[str, int], extensions: Optional[Tuple[str, ...]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, ) -> List[Tuple[str, int]]: """Generates a list of samples of a form (path_to_sample, class). @@ -172,6 +183,8 @@ def make_dataset( and checks if the file is a valid file (used to check of corrupt files) both extensions and is_valid_file should not be passed. Defaults to None. + allow_empty(bool, optional): If True, empty folders are considered to be valid classes. + An error is raised on empty folders if False (default). Raises: ValueError: In case ``class_to_idx`` is empty. @@ -186,7 +199,9 @@ def make_dataset( # find_classes() function, instead of using that of the find_classes() method, which # is potentially overridden and thus could have a different logic. raise ValueError("The class_to_idx parameter cannot be None.") - return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) + return make_dataset( + directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file, allow_empty=allow_empty + ) def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: """Find the class folders in a dataset structured as follows:: @@ -291,6 +306,8 @@ class ImageFolder(DatasetFolder): loader (callable, optional): A function to load an image given its path. is_valid_file (callable, optional): A function that takes path of an Image file and check if the file is a valid file (used to check of corrupt files) + allow_empty(bool, optional): If True, empty folders are considered to be valid classes. + An error is raised on empty folders if False (default). Attributes: classes (list): List of the class names sorted alphabetically. @@ -305,6 +322,7 @@ def __init__( target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = default_loader, is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, ): super().__init__( root, @@ -313,5 +331,6 @@ def __init__( transform=transform, target_transform=target_transform, is_valid_file=is_valid_file, + allow_empty=allow_empty, ) self.imgs = self.samples