Skip to content

Commit f6d33bb

Browse files
author
Caroline Chen
committed
Remove pandas dependency for CelebA dataset
1 parent 07fb8ba commit f6d33bb

File tree

3 files changed

+40
-18
lines changed

3 files changed

+40
-18
lines changed

test/datasets_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ class LazyImporter:
5353
MODULES = (
5454
"av",
5555
"lmdb",
56-
"pandas",
5756
"pycocotools",
5857
"requests",
5958
"scipy.io",

test/test_datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,6 @@ class CelebATestCase(datasets_utils.ImageDatasetTestCase):
616616
split=("train", "valid", "test", "all"),
617617
target_type=("attr", "identity", "bbox", "landmarks", ["attr", "identity"]),
618618
)
619-
REQUIRED_PACKAGES = ("pandas",)
620619

621620
_SPLIT_TO_IDX = dict(train=0, valid=1, test=2)
622621

torchvision/datasets/celeba.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
from collections import namedtuple
2+
import csv
13
from functools import partial
24
import torch
35
import os
6+
import numpy as np
47
import PIL
58
from typing import Any, Callable, List, Optional, Union, Tuple
69
from .vision import VisionDataset
710
from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
811

12+
CSV = namedtuple("CSV", ["header", "index", "data"])
913

1014
class CelebA(VisionDataset):
1115
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
@@ -61,7 +65,6 @@ def __init__(
6165
target_transform: Optional[Callable] = None,
6266
download: bool = False,
6367
) -> None:
64-
import pandas
6568
super(CelebA, self).__init__(root, transform=transform,
6669
target_transform=target_transform)
6770
self.split = split
@@ -88,23 +91,44 @@ def __init__(
8891
}
8992
split_ = split_map[verify_str_arg(split.lower(), "split",
9093
("train", "valid", "test", "all"))]
94+
splits = self._load_csv("list_eval_partition.txt", header=None, index_col=0)
95+
identity = self._load_csv("identity_CelebA.txt", header=None, index_col=0)
96+
bbox = self._load_csv("list_bbox_celeba.txt", header=1, index_col=0)
97+
landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1, index_col=0)
98+
attr = self._load_csv("list_attr_celeba.txt", header=1, index_col=0)
99+
100+
mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
101+
102+
self.filename = splits.index
103+
self.identity = identity.data[mask]
104+
self.bbox = bbox.data[mask]
105+
self.landmarks_align = landmarks_align.data[mask]
106+
self.attr = attr.data[mask]
107+
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
108+
self.attr_names = attr.header
109+
110+
def _load_csv(
111+
self,
112+
filename: str,
113+
header: int = None,
114+
index_col: int = None
115+
) -> CSV:
116+
data, indices, headers = [], [], []
91117

92118
fn = partial(os.path.join, self.root, self.base_folder)
93-
splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
94-
identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
95-
bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
96-
landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
97-
attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
98-
99-
mask = slice(None) if split_ is None else (splits[1] == split_)
100-
101-
self.filename = splits[mask].index.values
102-
self.identity = torch.as_tensor(identity[mask].values)
103-
self.bbox = torch.as_tensor(bbox[mask].values)
104-
self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
105-
self.attr = torch.as_tensor(attr[mask].values)
106-
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
107-
self.attr_names = list(attr.columns)
119+
with open(fn(filename)) as csv_file:
120+
data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))
121+
122+
if header is not None:
123+
headers = data[header]
124+
data = data[header + 1:]
125+
data_np = np.array(data)
126+
127+
if index_col is not None:
128+
indices = data_np[:, index_col]
129+
data_np = np.delete(data_np, index_col, axis=1)
130+
131+
return CSV(headers, indices, torch.as_tensor(data_np.astype(int)))
108132

109133
def _check_integrity(self) -> bool:
110134
for (_, md5, filename) in self.file_list:

0 commit comments

Comments
 (0)