@@ -51,13 +51,6 @@ class CIFAR10(data.Dataset):
51
51
'md5' : '5ff9c542aee3614f3951f8cda6e48888' ,
52
52
}
53
53
54
- @property
55
- def targets (self ):
56
- if self .train :
57
- return self .train_labels
58
- else :
59
- return self .test_labels
60
-
61
54
def __init__ (self , root , train = True ,
62
55
transform = None , target_transform = None ,
63
56
download = False ):
@@ -73,44 +66,30 @@ def __init__(self, root, train=True,
73
66
raise RuntimeError ('Dataset not found or corrupted.' +
74
67
' You can use download=True to download it' )
75
68
76
- # now load the picked numpy arrays
77
69
if self .train :
78
- self .train_data = []
79
- self .train_labels = []
80
- for fentry in self .train_list :
81
- f = fentry [0 ]
82
- file = os .path .join (self .root , self .base_folder , f )
83
- fo = open (file , 'rb' )
70
+ downloaded_list = self .train_list
71
+ else :
72
+ downloaded_list = self .test_list
73
+
74
+ self .data = []
75
+ self .targets = []
76
+
77
+ # now load the picked numpy arrays
78
+ for file_name , checksum in downloaded_list :
79
+ file_path = os .path .join (self .root , self .base_folder , file_name )
80
+ with open (file_path , 'rb' ) as f :
84
81
if sys .version_info [0 ] == 2 :
85
- entry = pickle .load (fo )
82
+ entry = pickle .load (f )
86
83
else :
87
- entry = pickle .load (fo , encoding = 'latin1' )
88
- self .train_data .append (entry ['data' ])
84
+ entry = pickle .load (f , encoding = 'latin1' )
85
+ self .data .append (entry ['data' ])
89
86
if 'labels' in entry :
90
- self .train_labels += entry ['labels' ]
87
+ self .targets . extend ( entry ['labels' ])
91
88
else :
92
- self .train_labels += entry ['fine_labels' ]
93
- fo .close ()
89
+ self .targets .extend (entry ['fine_labels' ])
94
90
95
- self .train_data = np .concatenate (self .train_data )
96
- self .train_data = self .train_data .reshape ((50000 , 3 , 32 , 32 ))
97
- self .train_data = self .train_data .transpose ((0 , 2 , 3 , 1 )) # convert to HWC
98
- else :
99
- f = self .test_list [0 ][0 ]
100
- file = os .path .join (self .root , self .base_folder , f )
101
- fo = open (file , 'rb' )
102
- if sys .version_info [0 ] == 2 :
103
- entry = pickle .load (fo )
104
- else :
105
- entry = pickle .load (fo , encoding = 'latin1' )
106
- self .test_data = entry ['data' ]
107
- if 'labels' in entry :
108
- self .test_labels = entry ['labels' ]
109
- else :
110
- self .test_labels = entry ['fine_labels' ]
111
- fo .close ()
112
- self .test_data = self .test_data .reshape ((10000 , 3 , 32 , 32 ))
113
- self .test_data = self .test_data .transpose ((0 , 2 , 3 , 1 )) # convert to HWC
91
+ self .data = np .vstack (self .data ).reshape (- 1 , 3 , 32 , 32 )
92
+ self .data = self .data .transpose ((0 , 2 , 3 , 1 )) # convert to HWC
114
93
115
94
self ._load_meta ()
116
95
@@ -135,10 +114,7 @@ def __getitem__(self, index):
135
114
Returns:
136
115
tuple: (image, target) where target is index of the target class.
137
116
"""
138
- if self .train :
139
- img , target = self .train_data [index ], self .train_labels [index ]
140
- else :
141
- img , target = self .test_data [index ], self .test_labels [index ]
117
+ img , target = self .data [index ], self .targets [index ]
142
118
143
119
# doing this so that it is consistent with all other datasets
144
120
# to return a PIL Image
@@ -153,10 +129,7 @@ def __getitem__(self, index):
153
129
return img , target
154
130
155
131
def __len__ (self ):
156
- if self .train :
157
- return len (self .train_data )
158
- else :
159
- return len (self .test_data )
132
+ return len (self .data )
160
133
161
134
def _check_integrity (self ):
162
135
root = self .root
@@ -174,16 +147,11 @@ def download(self):
174
147
print ('Files already downloaded and verified' )
175
148
return
176
149
177
- root = self .root
178
- download_url (self .url , root , self .filename , self .tgz_md5 )
150
+ download_url (self .url , self .root , self .filename , self .tgz_md5 )
179
151
180
152
# extract file
181
- cwd = os .getcwd ()
182
- tar = tarfile .open (os .path .join (root , self .filename ), "r:gz" )
183
- os .chdir (root )
184
- tar .extractall ()
185
- tar .close ()
186
- os .chdir (cwd )
153
+ with tarfile .open (os .path .join (self .root , self .filename ), "r:gz" ) as tar :
154
+ tar .extractall (path = self .root )
187
155
188
156
def __repr__ (self ):
189
157
fmt_str = 'Dataset ' + self .__class__ .__name__ + '\n '
0 commit comments