From 72557c21dcd32d2513de4583244cb7807cfb740e Mon Sep 17 00:00:00 2001
From: Caroline Chen
Date: Fri, 9 Apr 2021 12:01:04 -0700
Subject: [PATCH 1/3] Remove pandas dependency for CelebA dataset

---
 test/datasets_utils.py         |  1 -
 test/test_datasets.py          |  1 -
 torchvision/datasets/celeba.py | 57 ++++++++++++++++++++++++----------
 3 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/test/datasets_utils.py b/test/datasets_utils.py
index 658ef6640fe..60e3990f3a2 100644
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ -53,7 +53,6 @@ class LazyImporter:
     MODULES = (
         "av",
         "lmdb",
-        "pandas",
         "pycocotools",
         "requests",
         "scipy.io",
diff --git a/test/test_datasets.py b/test/test_datasets.py
index db80b55a90f..0c4cbfeac70 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -616,7 +616,6 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
         split=("train", "valid", "test", "all"),
         target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
     )
-    REQUIRED_PACKAGES = ("pandas",)
 
     _SPLIT_TO_IDX = dict(train=0, valid=1, test=2)
 
diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py
index cc6f8084a80..71dad6d7662 100644
--- a/torchvision/datasets/celeba.py
+++ b/torchvision/datasets/celeba.py
@@ -1,11 +1,16 @@
+from collections import namedtuple
+import csv
 from functools import partial
 import torch
 import os
+import numpy as np
 import PIL
 from typing import Any, Callable, List, Optional, Union, Tuple
 from .vision import VisionDataset
 from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
 
+CSV = namedtuple("CSV", ["header", "index", "data"])
+
 
 class CelebA(VisionDataset):
     """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
@@ -61,7 +66,6 @@ def __init__(
             target_transform: Optional[Callable] = None,
             download: bool = False,
     ) -> None:
-        import pandas
         super(CelebA, self).__init__(root, transform=transform, target_transform=target_transform)
         self.split = split
@@ -88,23 +92,44 @@ def __init__(
         }
         split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
 
+        splits = self._load_csv("list_eval_partition.txt", header=None, index_col=0)
+        identity = self._load_csv("identity_CelebA.txt", header=None, index_col=0)
+        bbox = self._load_csv("list_bbox_celeba.txt", header=1, index_col=0)
+        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1, index_col=0)
+        attr = self._load_csv("list_attr_celeba.txt", header=1, index_col=0)
+
+        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
+
+        self.filename = splits.index
+        self.identity = identity.data[mask]
+        self.bbox = bbox.data[mask]
+        self.landmarks_align = landmarks_align.data[mask]
+        self.attr = attr.data[mask]
+        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
+        self.attr_names = attr.header
+
+    def _load_csv(
+        self,
+        filename: str,
+        header: int = None,
+        index_col: int = None
+    ) -> CSV:
+        data, indices, headers = [], [], []
+
         fn = partial(os.path.join, self.root, self.base_folder)
-        splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
-        identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
-        bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
-        landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
-        attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
-
-        mask = slice(None) if split_ is None else (splits[1] == split_)
-
-        self.filename = splits[mask].index.values
-        self.identity = torch.as_tensor(identity[mask].values)
-        self.bbox = torch.as_tensor(bbox[mask].values)
-        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
-        self.attr = torch.as_tensor(attr[mask].values)
-        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
-        self.attr_names = list(attr.columns)
+        with open(fn(filename)) as csv_file:
+            data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))
+
+        if header is not None:
+            headers = data[header]
+            data = data[header + 1:]
+        data_np = np.array(data)
+
+        if index_col is not None:
+            indices = data_np[:, index_col]
+            data_np = np.delete(data_np, index_col, axis=1)
+
+        return CSV(headers, indices, torch.as_tensor(data_np.astype(int)))
 
     def _check_integrity(self) -> bool:
         for (_, md5, filename) in self.file_list:

From 6da2b815667989af2a6e60cbea3c144e952e024d Mon Sep 17 00:00:00 2001
From: Caroline Chen
Date: Sun, 11 Apr 2021 22:37:14 -0700
Subject: [PATCH 2/3] address PR comments

---
 torchvision/datasets/celeba.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py
index 71dad6d7662..0920105f238 100644
--- a/torchvision/datasets/celeba.py
+++ b/torchvision/datasets/celeba.py
@@ -3,7 +3,6 @@
 from functools import partial
 import torch
 import os
-import numpy as np
 import PIL
 from typing import Any, Callable, List, Optional, Union, Tuple
 from .vision import VisionDataset
@@ -92,11 +91,11 @@ def __init__(
         }
         split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
 
-        splits = self._load_csv("list_eval_partition.txt", header=None, index_col=0)
-        identity = self._load_csv("identity_CelebA.txt", header=None, index_col=0)
-        bbox = self._load_csv("list_bbox_celeba.txt", header=1, index_col=0)
-        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1, index_col=0)
-        attr = self._load_csv("list_attr_celeba.txt", header=1, index_col=0)
+        splits = self._load_csv("list_eval_partition.txt")
+        identity = self._load_csv("identity_CelebA.txt")
+        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
+        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
+        attr = self._load_csv("list_attr_celeba.txt", header=1)
 
         mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
 
         self.filename = splits.index
@@ -112,7 +111,6 @@ def __init__(
     def _load_csv(
         self,
         filename: str,
         header: int = None,
-        index_col: int = None
     ) -> CSV:
         data, indices, headers = [], [], []
@@ -123,13 +121,12 @@ def _load_csv(
         if header is not None:
             headers = data[header]
             data = data[header + 1:]
-        data_np = np.array(data)
 
-        if index_col is not None:
-            indices = data_np[:, index_col]
-            data_np = np.delete(data_np, index_col, axis=1)
+        indices = [row[0] for row in data]
+        data = [row[1:] for row in data]
+        data_int = [list(map(int, i)) for i in data]
 
-        return CSV(headers, indices, torch.as_tensor(data_np.astype(int)))
+        return CSV(headers, indices, torch.tensor(data_int))
 
     def _check_integrity(self) -> bool:
         for (_, md5, filename) in self.file_list:

From f956996f386ea4de572108da74e53f9042ef2fdf Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Mon, 12 Apr 2021 09:38:53 +0000
Subject: [PATCH 3/3] Apply suggestions from code review

Co-authored-by: Philip Meier
---
 torchvision/datasets/celeba.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py
index 0920105f238..5c202da05b9 100644
--- a/torchvision/datasets/celeba.py
+++ b/torchvision/datasets/celeba.py
@@ -110,7 +110,7 @@ def __init__(
     def _load_csv(
         self,
         filename: str,
-        header: int = None,
+        header: Optional[int] = None,
     ) -> CSV:
         data, indices, headers = [], [], []
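
The approach the series lands on replaces pandas.read_csv(..., delim_whitespace=True) with csv.reader(..., delimiter=' ', skipinitialspace=True), packs the result into the CSV namedtuple, and treats the header argument as the row index of the column names (header=1 because row 0 of the CelebA annotation files holds the image count). The snippet below is a minimal, self-contained sketch of that parsing strategy for reference only; it is not part of the patches, and SAMPLE_LINES and load_csv are illustrative stand-ins (a file-object-free variant of the patched _load_csv method), not torchvision code.

    import csv
    from collections import namedtuple

    import torch

    CSV = namedtuple("CSV", ["header", "index", "data"])

    # Made-up rows in the layout of list_attr_celeba.txt: row 0 is the image
    # count, row 1 holds the attribute names, the remaining rows are
    # space-separated "<filename> <attribute values in {-1, 1}>".
    SAMPLE_LINES = [
        "3",
        "Attractive Smiling Young",
        "000001.jpg  1 -1  1",
        "000002.jpg -1 -1  1",
        "000003.jpg  1  1 -1",
    ]

    def load_csv(lines, header=None):
        # delimiter=' ' plus skipinitialspace=True collapses runs of spaces,
        # standing in for pandas.read_csv(..., delim_whitespace=True).
        data = list(csv.reader(lines, delimiter=' ', skipinitialspace=True))

        headers = []
        if header is not None:
            headers = data[header]      # the row holding the column names
            data = data[header + 1:]    # everything below it is data

        indices = [row[0] for row in data]    # first column: the file names
        data = [row[1:] for row in data]      # remaining columns are numeric
        data_int = [list(map(int, i)) for i in data]
        return CSV(headers, indices, torch.tensor(data_int))

    attr = load_csv(SAMPLE_LINES, header=1)
    print(attr.header)           # ['Attractive', 'Smiling', 'Young']
    print(attr.index)            # ['000001.jpg', '000002.jpg', '000003.jpg']
    print((attr.data + 1) // 2)  # attribute matrix remapped from {-1, 1} to {0, 1}

One caveat worth noting: splitting on a single-space delimiter with skipinitialspace only reproduces pandas' delim_whitespace behaviour for space-separated rows without leading whitespace, which is how the CelebA annotation files are laid out.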