
Commit fc7911c

dizczafmassa authored and committed
CIFAR: permanent 'data' and 'targets' fields (#594)
1 parent f3d5e85 commit fc7911c

File tree

1 file changed: +23 −55 lines changed


torchvision/datasets/cifar.py

Lines changed: 23 additions & 55 deletions
@@ -51,13 +51,6 @@ class CIFAR10(data.Dataset):
         'md5': '5ff9c542aee3614f3951f8cda6e48888',
     }
 
-    @property
-    def targets(self):
-        if self.train:
-            return self.train_labels
-        else:
-            return self.test_labels
-
     def __init__(self, root, train=True,
                  transform=None, target_transform=None,
                  download=False):
@@ -73,44 +66,30 @@ def __init__(self, root, train=True,
             raise RuntimeError('Dataset not found or corrupted.' +
                                ' You can use download=True to download it')
 
-        # now load the picked numpy arrays
         if self.train:
-            self.train_data = []
-            self.train_labels = []
-            for fentry in self.train_list:
-                f = fentry[0]
-                file = os.path.join(self.root, self.base_folder, f)
-                fo = open(file, 'rb')
+            downloaded_list = self.train_list
+        else:
+            downloaded_list = self.test_list
+
+        self.data = []
+        self.targets = []
+
+        # now load the picked numpy arrays
+        for file_name, checksum in downloaded_list:
+            file_path = os.path.join(self.root, self.base_folder, file_name)
+            with open(file_path, 'rb') as f:
                 if sys.version_info[0] == 2:
-                    entry = pickle.load(fo)
+                    entry = pickle.load(f)
                 else:
-                    entry = pickle.load(fo, encoding='latin1')
-                self.train_data.append(entry['data'])
+                    entry = pickle.load(f, encoding='latin1')
+                self.data.append(entry['data'])
                 if 'labels' in entry:
-                    self.train_labels += entry['labels']
+                    self.targets.extend(entry['labels'])
                 else:
-                    self.train_labels += entry['fine_labels']
-                fo.close()
+                    self.targets.extend(entry['fine_labels'])
 
-            self.train_data = np.concatenate(self.train_data)
-            self.train_data = self.train_data.reshape((50000, 3, 32, 32))
-            self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
-        else:
-            f = self.test_list[0][0]
-            file = os.path.join(self.root, self.base_folder, f)
-            fo = open(file, 'rb')
-            if sys.version_info[0] == 2:
-                entry = pickle.load(fo)
-            else:
-                entry = pickle.load(fo, encoding='latin1')
-            self.test_data = entry['data']
-            if 'labels' in entry:
-                self.test_labels = entry['labels']
-            else:
-                self.test_labels = entry['fine_labels']
-            fo.close()
-            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
-            self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC
+        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
+        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
 
         self._load_meta()
 
@@ -135,10 +114,7 @@ def __getitem__(self, index):
         Returns:
             tuple: (image, target) where target is index of the target class.
         """
-        if self.train:
-            img, target = self.train_data[index], self.train_labels[index]
-        else:
-            img, target = self.test_data[index], self.test_labels[index]
+        img, target = self.data[index], self.targets[index]
 
         # doing this so that it is consistent with all other datasets
         # to return a PIL Image
@@ -153,10 +129,7 @@ def __getitem__(self, index):
         return img, target
 
     def __len__(self):
-        if self.train:
-            return len(self.train_data)
-        else:
-            return len(self.test_data)
+        return len(self.data)
 
     def _check_integrity(self):
         root = self.root
@@ -174,16 +147,11 @@ def download(self):
             print('Files already downloaded and verified')
             return
 
-        root = self.root
-        download_url(self.url, root, self.filename, self.tgz_md5)
+        download_url(self.url, self.root, self.filename, self.tgz_md5)
 
         # extract file
-        cwd = os.getcwd()
-        tar = tarfile.open(os.path.join(root, self.filename), "r:gz")
-        os.chdir(root)
-        tar.extractall()
-        tar.close()
-        os.chdir(cwd)
+        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
+            tar.extractall(path=self.root)
 
     def __repr__(self):
         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
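
For context, a minimal usage sketch of the behaviour after this change (the root path './data' is just an illustrative choice): both the train and test splits of CIFAR10 now expose the same permanent data and targets attributes, so downstream code no longer needs to branch on train_data/test_data.

from torchvision.datasets import CIFAR10

# Illustrative root directory; download=True fetches the archive if it is missing.
train_set = CIFAR10(root='./data', train=True, download=True)
test_set = CIFAR10(root='./data', train=False, download=True)

# 'data' is a single HWC uint8 numpy array and 'targets' a list of class indices,
# regardless of the train flag.
print(train_set.data.shape)    # (50000, 32, 32, 3)
print(test_set.data.shape)     # (10000, 32, 32, 3)
print(len(test_set.targets))   # 10000

# __getitem__ and __len__ now read from self.data / self.targets directly.
img, target = train_set[0]
print(len(train_set))          # 50000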

0 commit comments
