Skip to content

Commit a15ff20

Browse files
ArdalanM authored and fmassa committed
add tar.xz archive handler (#1361)
* add tar.xz archive handler
* update unittest for tar.xz archive
* remove .tar.xz unittest
* add separate .tar.xz unittest
* update PY2 compatibility
1 parent 1909495 commit a15ff20

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

test/test_datasets_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,23 @@ def test_extract_tar(self):
100100
data = nf.read()
101101
self.assertEqual(data, 'this is the content')
102102

103+
# NOTE: the substring check "'win' in sys.platform" is also true for 'darwin',
# which would skip this test on macOS as well; match Windows platforms only.
@unittest.skipIf(sys.platform in ('win32', 'cygwin'), 'temporarily disabled on Windows')
@unittest.skipIf(sys.version_info < (3,), "Extracting .tar.xz files is not supported under Python 2.x")
def test_extract_tar_xz(self):
    """Round-trip one file through a ``.tar.xz`` archive.

    Writes a small payload to a temporary file, packs it into an
    xz-compressed tar archive under the name ``file.tst``, extracts the
    archive with ``utils.extract_archive`` and checks that the extracted
    file exists and holds the original payload.
    """
    content = 'this is the content'
    with get_tmp_dir() as temp_dir:
        with tempfile.NamedTemporaryFile() as bf:
            bf.write(content.encode())
            # tarfile re-opens the payload by name, so the buffered bytes
            # must reach the file before it is added to the archive.
            bf.flush()
            with tempfile.NamedTemporaryFile(suffix='.tar.xz') as f:
                # Close the archive before extracting so it is fully written.
                with tarfile.open(f.name, mode='w:xz') as zf:
                    zf.add(bf.name, arcname='file.tst')
                utils.extract_archive(f.name, temp_dir)
                extracted = os.path.join(temp_dir, 'file.tst')
                self.assertTrue(os.path.exists(extracted))
                with open(extracted, 'r') as nf:
                    data = nf.read()
                self.assertEqual(data, content)
119+
103120
@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
104121
def test_extract_gzip(self):
105122
with get_tmp_dir() as temp_dir:

torchvision/datasets/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010
from torch.utils.model_zoo import tqdm
11+
from torch._six import PY3
1112

1213

1314
def gen_bar_updater():
@@ -197,6 +198,10 @@ def _save_response_content(response, destination, chunk_size=32768):
197198
pbar.close()
198199

199200

201+
def _is_tarxz(filename):
202+
return filename.endswith(".tar.xz")
203+
204+
200205
def _is_tar(filename):
201206
return filename.endswith(".tar")
202207

@@ -223,6 +228,10 @@ def extract_archive(from_path, to_path=None, remove_finished=False):
223228
elif _is_targz(from_path):
224229
with tarfile.open(from_path, 'r:gz') as tar:
225230
tar.extractall(path=to_path)
231+
elif _is_tarxz(from_path) and PY3:
232+
# .tar.xz archive only supported in Python 3.x
233+
with tarfile.open(from_path, 'r:xz') as tar:
234+
tar.extractall(path=to_path)
226235
elif _is_gzip(from_path):
227236
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
228237
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:

0 commit comments

Comments
 (0)