+from collections import namedtuple
+import csv
 from functools import partial
 import torch
 import os
+import numpy as np
 import PIL
 from typing import Any, Callable, List, Optional, Union, Tuple
 from .vision import VisionDataset
 from .utils import download_file_from_google_drive, check_integrity, verify_str_arg
 
+CSV = namedtuple("CSV", ["header", "index", "data"])
 
 class CelebA(VisionDataset):
     """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
@@ -61,7 +65,6 @@ def __init__(
             target_transform: Optional[Callable] = None,
             download: bool = False,
     ) -> None:
-        import pandas
         super(CelebA, self).__init__(root, transform=transform,
                                      target_transform=target_transform)
         self.split = split
@@ -88,23 +91,44 @@ def __init__(
         }
         split_ = split_map[verify_str_arg(split.lower(), "split",
                                           ("train", "valid", "test", "all"))]
+        splits = self._load_csv("list_eval_partition.txt", header=None, index_col=0)
+        identity = self._load_csv("identity_CelebA.txt", header=None, index_col=0)
+        bbox = self._load_csv("list_bbox_celeba.txt", header=1, index_col=0)
+        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1, index_col=0)
+        attr = self._load_csv("list_attr_celeba.txt", header=1, index_col=0)
+
+        mask = slice(None) if split_ is None else (splits.data == split_).squeeze()
+
+        self.filename = splits.index if split_ is None else splits.index[mask.numpy()]  # filter filenames by split too
+        self.identity = identity.data[mask]
+        self.bbox = bbox.data[mask]
+        self.landmarks_align = landmarks_align.data[mask]
+        self.attr = attr.data[mask]
+        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
+        self.attr_names = attr.header
+
+    def _load_csv(
+        self,
+        filename: str,
+        header: Optional[int] = None,
+        index_col: Optional[int] = None
+    ) -> CSV:
+        data, indices, headers = [], [], []
 
         fn = partial(os.path.join, self.root, self.base_folder)
-        splits = pandas.read_csv(fn("list_eval_partition.txt"), delim_whitespace=True, header=None, index_col=0)
-        identity = pandas.read_csv(fn("identity_CelebA.txt"), delim_whitespace=True, header=None, index_col=0)
-        bbox = pandas.read_csv(fn("list_bbox_celeba.txt"), delim_whitespace=True, header=1, index_col=0)
-        landmarks_align = pandas.read_csv(fn("list_landmarks_align_celeba.txt"), delim_whitespace=True, header=1)
-        attr = pandas.read_csv(fn("list_attr_celeba.txt"), delim_whitespace=True, header=1)
-
-        mask = slice(None) if split_ is None else (splits[1] == split_)
-
-        self.filename = splits[mask].index.values
-        self.identity = torch.as_tensor(identity[mask].values)
-        self.bbox = torch.as_tensor(bbox[mask].values)
-        self.landmarks_align = torch.as_tensor(landmarks_align[mask].values)
-        self.attr = torch.as_tensor(attr[mask].values)
-        self.attr = (self.attr + 1) // 2  # map from {-1, 1} to {0, 1}
-        self.attr_names = list(attr.columns)
+        with open(fn(filename)) as csv_file:
+            data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))
+
+        if header is not None:
+            headers = data[header]
+            data = data[header + 1:]
+        data_np = np.array(data)
+
+        if index_col is not None:
+            indices = data_np[:, index_col]
+            data_np = np.delete(data_np, index_col, axis=1)
+
+        return CSV(headers, indices, torch.as_tensor(data_np.astype(int)))
 
     def _check_integrity(self) -> bool:
         for (_, md5, filename) in self.file_list:
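For reference, below is a minimal, self-contained sketch of the parsing strategy the new _load_csv helper uses: csv.reader with a space delimiter and skipinitialspace, an optional header row, an optional index column, and an integer tensor payload. The load_csv function name, the toy_attr.txt file and its rows are made up for illustration; the real list_attr_celeba.txt additionally puts the number of images on its first line, which is why it is read with header=1 in the patch.

import csv
from collections import namedtuple

import numpy as np
import torch

CSV = namedtuple("CSV", ["header", "index", "data"])


def load_csv(path, header=None, index_col=None):
    # Same steps as the patched _load_csv, minus the dataset plumbing:
    # whitespace-delimited rows, optional header row, optional index column.
    with open(path) as csv_file:
        data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))
    headers, indices = [], []
    if header is not None:
        headers = data[header]
        data = data[header + 1:]
    data_np = np.array(data)
    if index_col is not None:
        indices = data_np[:, index_col]
        data_np = np.delete(data_np, index_col, axis=1)
    return CSV(headers, indices, torch.as_tensor(data_np.astype(int)))


# Hypothetical two-image attribute file mirroring the
# "<image name> <attr values ...>" layout of list_attr_celeba.txt.
with open("toy_attr.txt", "w") as f:
    f.write("Smiling Young\n")
    f.write("000001.jpg  1 -1\n")
    f.write("000002.jpg -1  1\n")

parsed = load_csv("toy_attr.txt", header=0, index_col=0)
print(parsed.header)  # ['Smiling', 'Young']
print(parsed.index)   # ['000001.jpg' '000002.jpg']
print(parsed.data)    # tensor([[ 1, -1], [-1,  1]])

The {-1, 1} attribute values returned here are what the __init__ change then maps to {0, 1} via (self.attr + 1) // 2.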