class Tester(unittest.TestCase):
    """Smoke tests: each MNIST-family dataset downloads and yields (PIL image, int) samples."""

    def _check_mnist_like(self, dataset_cls, expected_len=None):
        """Download ``dataset_cls`` into a fresh temp dir and sanity-check sample 0.

        The temp dir is removed via ``addCleanup`` so it is reclaimed even
        when an assertion fails (the original only cleaned up on success).
        """
        tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        dataset = dataset_cls(tmp_dir, download=True)
        if expected_len is not None:
            self.assertEqual(len(dataset), expected_len)
        img, target = dataset[0]
        self.assertIsInstance(img, PIL.Image.Image)
        self.assertIsInstance(target, int)

    def test_mnist(self):
        # MNIST's training split is known to contain exactly 60000 samples.
        self._check_mnist_like(torchvision.datasets.MNIST, expected_len=60000)

    def test_kmnist(self):
        self._check_mnist_like(torchvision.datasets.KMNIST)

    def test_fashionmnist(self):
        self._check_mnist_like(torchvision.datasets.FashionMNIST)
shutil.rmtree(temp_dir) + def test_extract_zip(self): + temp_dir = tempfile.mkdtemp() + with tempfile.NamedTemporaryFile(suffix='.zip') as f: + with zipfile.ZipFile(f, 'w') as zf: + zf.writestr('file.tst', 'this is the content') + utils.extract_file(f.name, temp_dir) + assert os.path.exists(os.path.join(temp_dir, 'file.tst')) + with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: + data = nf.read() + assert data == 'this is the content' + shutil.rmtree(temp_dir) + + def test_extract_tar(self): + for ext, mode in zip(['.tar', '.tar.gz'], ['w', 'w:gz']): + temp_dir = tempfile.mkdtemp() + with tempfile.NamedTemporaryFile() as bf: + bf.write("this is the content".encode()) + bf.seek(0) + with tempfile.NamedTemporaryFile(suffix=ext) as f: + with tarfile.open(f.name, mode=mode) as zf: + zf.add(bf.name, arcname='file.tst') + utils.extract_file(f.name, temp_dir) + assert os.path.exists(os.path.join(temp_dir, 'file.tst')) + with open(os.path.join(temp_dir, 'file.tst'), 'r') as nf: + data = nf.read() + assert data == 'this is the content', data + shutil.rmtree(temp_dir) + + def test_extract_gzip(self): + temp_dir = tempfile.mkdtemp() + with tempfile.NamedTemporaryFile(suffix='.gz') as f: + with gzip.GzipFile(f.name, 'wb') as zf: + zf.write('this is the content'.encode()) + utils.extract_file(f.name, temp_dir) + f_name = os.path.join(temp_dir, os.path.splitext(os.path.basename(f.name))[0]) + assert os.path.exists(f_name) + with open(os.path.join(f_name), 'r') as nf: + data = nf.read() + assert data == 'this is the content', data + shutil.rmtree(temp_dir) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index 8c477e64810..43e8d7caf8d 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -4,7 +4,7 @@ import os.path from .vision import VisionDataset -from .utils import download_url, makedir_exist_ok +from .utils import download_and_extract, makedir_exist_ok class 
Caltech101(VisionDataset): @@ -109,27 +109,20 @@ def __len__(self): return len(self.index) def download(self): - import tarfile - if self._check_integrity(): print('Files already downloaded and verified') return - download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", - self.root, - "101_ObjectCategories.tar.gz", - "b224c7392d521a49829488ab0f1120d9") - download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", - self.root, - "101_Annotations.tar", - "6f83eeb1f24d99cab4eb377263132c91") - - # extract file - with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar: - tar.extractall(path=self.root) - - with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar: - tar.extractall(path=self.root) + download_and_extract( + "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", + self.root, + "101_ObjectCategories.tar.gz", + "b224c7392d521a49829488ab0f1120d9") + download_and_extract( + "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", + self.root, + "101_Annotations.tar", + "6f83eeb1f24d99cab4eb377263132c91") def extra_repr(self): return "Target type: {target_type}".format(**self.__dict__) @@ -204,17 +197,12 @@ def __len__(self): return len(self.index) def download(self): - import tarfile - if self._check_integrity(): print('Files already downloaded and verified') return - download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", - self.root, - "256_ObjectCategories.tar", - "67b4f42ca05d46448c6bb8ecd2220f6d") - - # extract file - with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar: - tar.extractall(path=self.root) + download_and_extract( + "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", + self.root, + "256_ObjectCategories.tar", + "67b4f42ca05d46448c6bb8ecd2220f6d") diff --git 
a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index 307e8f60a06..59ecda5cf4d 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -11,7 +11,7 @@ import pickle from .vision import VisionDataset -from .utils import download_url, check_integrity +from .utils import check_integrity, download_and_extract class CIFAR10(VisionDataset): @@ -144,17 +144,10 @@ def _check_integrity(self): return True def download(self): - import tarfile - if self._check_integrity(): print('Files already downloaded and verified') return - - download_url(self.url, self.root, self.filename, self.tgz_md5) - - # extract file - with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: - tar.extractall(path=self.root) + download_and_extract(self.url, self.root, self.filename, self.tgz_md5) def extra_repr(self): return "Split: {}".format("Train" if self.train is True else "Test") diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index e1a277d2de7..51a3cb19df4 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -4,11 +4,10 @@ from PIL import Image import os import os.path -import gzip import numpy as np import torch import codecs -from .utils import download_url, makedir_exist_ok +from .utils import download_and_extract, extract_file, makedir_exist_ok class MNIST(VisionDataset): @@ -120,15 +119,6 @@ def _check_exists(self): os.path.exists(os.path.join(self.processed_folder, self.test_file))) - @staticmethod - def extract_gzip(gzip_path, remove_finished=False): - print('Extracting {}'.format(gzip_path)) - with open(gzip_path.replace('.gz', ''), 'wb') as out_f, \ - gzip.GzipFile(gzip_path) as zip_f: - out_f.write(zip_f.read()) - if remove_finished: - os.unlink(gzip_path) - def download(self): """Download the MNIST data if it doesn't exist in processed_folder already.""" @@ -141,9 +131,7 @@ def download(self): # download files for url in self.urls: filename = url.rpartition('/')[2] - 
file_path = os.path.join(self.raw_folder, filename) - download_url(url, root=self.raw_folder, filename=filename, md5=None) - self.extract_gzip(gzip_path=file_path, remove_finished=True) + download_and_extract(url, root=self.raw_folder, filename=filename) # process and save as torch files print('Processing...') @@ -262,7 +250,6 @@ def _test_file(split): def download(self): """Download the EMNIST data if it doesn't exist in processed_folder already.""" import shutil - import zipfile if self._check_exists(): return @@ -271,18 +258,12 @@ def download(self): makedir_exist_ok(self.processed_folder) # download files - filename = self.url.rpartition('/')[2] - file_path = os.path.join(self.raw_folder, filename) - download_url(self.url, root=self.raw_folder, filename=filename, md5=None) - - print('Extracting zip archive') - with zipfile.ZipFile(file_path) as zip_f: - zip_f.extractall(self.raw_folder) - os.unlink(file_path) + print('Downloading and extracting zip archive') + download_and_extract(self.url, root=self.raw_folder, filename="emnist.zip", remove_finished=True) gzip_folder = os.path.join(self.raw_folder, 'gzip') for gzip_file in os.listdir(gzip_folder): if gzip_file.endswith('.gz'): - self.extract_gzip(gzip_path=os.path.join(gzip_folder, gzip_file)) + extract_file(os.path.join(gzip_folder, gzip_file), gzip_folder) # process and save as torch files for split in self.splits: diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index 9e2af0157a0..b5f6d64f12e 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -3,7 +3,7 @@ from os.path import join import os from .vision import VisionDataset -from .utils import download_url, check_integrity, list_dir, list_files +from .utils import download_and_extract, check_integrity, list_dir, list_files class Omniglot(VisionDataset): @@ -81,8 +81,6 @@ def _check_integrity(self): return True def download(self): - import zipfile - if self._check_integrity(): print('Files 
def _is_tar(filename):
    # Plain (uncompressed) tar archive; ".tar.gz" does not match because it
    # is tested by _is_targz before this predicate is consulted.
    return filename.endswith(".tar")


def _is_targz(filename):
    return filename.endswith(".tar.gz")


def _is_gzip(filename):
    # A bare gzip-compressed file (e.g. "train-images.gz"), not a tarball.
    return filename.endswith(".gz") and not filename.endswith(".tar.gz")


def _is_zip(filename):
    return filename.endswith(".zip")


def extract_file(from_path, to_path, remove_finished=False):
    """Extract the archive at ``from_path`` into the directory ``to_path``.

    The archive type is dispatched on the file-name suffix: ``.tar``,
    ``.tar.gz``, ``.zip``, or a bare ``.gz`` file.  For a bare ``.gz`` the
    decompressed data is written to ``to_path/<basename without .gz>``.

    Args:
        from_path (str): path of the archive to extract.
        to_path (str): directory the contents are extracted into.
        remove_finished (bool): if True, delete the archive afterwards.

    Raises:
        ValueError: if the file name matches no supported archive type.
    """
    if _is_tar(from_path):
        # NOTE(review): extractall trusts member paths (CVE-2007-4559);
        # archives fetched from the pinned dataset URLs are assumed
        # well-formed — confirm before extracting untrusted archives.
        with tarfile.open(from_path, 'r:') as tar:
            tar.extractall(path=to_path)
    elif _is_targz(from_path):
        with tarfile.open(from_path, 'r:gz') as tar:
            tar.extractall(path=to_path)
    elif _is_gzip(from_path):
        to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
        # Stream in chunks instead of zip_f.read(): never holds the whole
        # decompressed payload in memory (MNIST-style files can be large).
        with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
            while True:
                chunk = zip_f.read(1024 * 1024)
                if not chunk:
                    break
                out_f.write(chunk)
    elif _is_zip(from_path):
        with zipfile.ZipFile(from_path, 'r') as z:
            z.extractall(to_path)
    else:
        raise ValueError("Extraction of {} not supported".format(from_path))

    if remove_finished:
        os.unlink(from_path)


def download_and_extract(url, root, filename, md5=None, remove_finished=False):
    """Download ``url`` to ``root/filename`` (verifying ``md5`` when given),
    then extract the archive into ``root``.

    Args:
        url (str): URL to download.
        root (str): directory the file is saved to and extracted into.
        filename (str): name to save the downloaded file under.
        md5 (str, optional): expected md5 checksum for integrity checking.
        remove_finished (bool): delete the archive after extraction.
    """
    download_url(url, root, filename, md5)
    print("Extracting {} to {}".format(os.path.join(root, filename), root))
    extract_file(os.path.join(root, filename), root, remove_finished)