diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index b3c5f9e9a6a..e9dd883b92e 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,13 +1,12 @@ import csv import os -import warnings from collections import namedtuple from typing import Any, Callable, List, Optional, Union, Tuple import PIL import torch -from .utils import check_integrity, verify_str_arg +from .utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive from .vision import VisionDataset CSV = namedtuple("CSV", ["header", "index", "data"]) @@ -36,17 +35,9 @@ class CelebA(VisionDataset): and returns a transformed version. E.g, ``transforms.PILToTensor`` target_transform (callable, optional): A function/transform that takes in the target and transforms it. - download (bool, optional): Deprecated. - - .. warning:: - - Downloading CelebA is not supported anymore as of 0.13 and this - parameter will be removed in 0.15. See - `this issue `__ - for more details. - Please download the files from - https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract - them in ``root/celeba``. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. """ base_folder = "celeba" @@ -73,7 +64,7 @@ def __init__( target_type: Union[List[str], str] = "attr", transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, - download: bool = None, + download: bool = False, ) -> None: super().__init__(root, transform=transform, target_transform=target_transform) self.split = split @@ -85,15 +76,6 @@ def __init__( if not self.target_type and self.target_transform is not None: raise RuntimeError("target_transform is specified but target_type is empty") - if download is not None: - warnings.warn( - "Downloading CelebA is not supported anymore as of 0.13, and the " - "download parameter will be removed in 0.15. See " - "https://github.com/pytorch/vision/issues/5705 for more details. " - "Please download the files from " - "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them " - "in ``root/celeba``." - ) if download: self.download() @@ -164,14 +146,10 @@ def download(self) -> None: print("Files already downloaded and verified") return - raise ValueError( - "Downloading CelebA is not supported anymore as of 0.13, and the " - "download parameter will be removed in 0.15. See " - "https://github.com/pytorch/vision/issues/5705 for more details. " - "Please download the files from " - "https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html and extract them " - "in ``root/celeba``." - ) + for (file_id, md5, filename) in self.file_list: + download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) + + extract_archive(os.path.join(self.root, self.base_folder, "img_align_celeba.zip")) def __getitem__(self, index: int) -> Tuple[Any, Any]: X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))