Skip to content

Commit d1ab583

Browse files
authored
Add support for files with periods in name (#4099)
1 parent 850491e commit d1ab583

File tree

2 files changed

+26
-37
lines changed

2 files changed

+26
-37
lines changed

test/test_datasets_utils.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def test_detect_file_type(self):
6363
("foo.gz", (".gz", None, ".gz")),
6464
("foo.zip", (".zip", ".zip", None)),
6565
("foo.xz", (".xz", None, ".xz")),
66+
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
67+
("foo.bar.gz", (".gz", None, ".gz")),
68+
("foo.bar.zip", (".zip", ".zip", None)),
6669
]:
6770
with self.subTest(file=file):
6871
self.assertSequenceEqual(utils._detect_file_type(file), expected)
@@ -71,14 +74,6 @@ def test_detect_file_type_no_ext(self):
7174
with self.assertRaises(RuntimeError):
7275
utils._detect_file_type("foo")
7376

74-
def test_detect_file_type_to_many_exts(self):
75-
with self.assertRaises(RuntimeError):
76-
utils._detect_file_type("foo.bar.tar.gz")
77-
78-
def test_detect_file_type_unknown_archive_type(self):
79-
with self.assertRaises(RuntimeError):
80-
utils._detect_file_type("foo.bar.gz")
81-
8277
def test_detect_file_type_unknown_compression(self):
8378
with self.assertRaises(RuntimeError):
8479
utils._detect_file_type("foo.tar.baz")

torchvision/datasets/utils.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -291,53 +291,47 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
291291
}
292292

293293

294-
def _verify_archive_type(archive_type: str) -> None:
295-
if archive_type not in _ARCHIVE_EXTRACTORS.keys():
296-
valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys())
297-
raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")
298-
294+
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
295+
"""Detect the archive type and/or compression of a file.
299296
300-
def _verify_compression(compression: str) -> None:
301-
if compression not in _COMPRESSED_FILE_OPENERS.keys():
302-
valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys())
303-
raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")
297+
Args:
298+
file (str): the filename
304299
300+
Returns:
301+
(tuple): tuple of suffix, archive type, and compression
305302
306-
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
307-
path = pathlib.Path(file)
308-
suffix = path.suffix
303+
Raises:
304+
RuntimeError: if file has no suffix or suffix is not supported
305+
"""
309306
suffixes = pathlib.Path(file).suffixes
310307
if not suffixes:
311308
raise RuntimeError(
312309
f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
313310
)
314-
elif len(suffixes) > 2:
315-
raise RuntimeError(
316-
"Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
317-
)
318-
elif len(suffixes) == 2:
319-
# if we have exactly two suffixes we assume the first one is the archive type and the second on is the
320-
# compression
321-
archive_type, compression = suffixes
322-
_verify_archive_type(archive_type)
323-
_verify_compression(compression)
324-
return "".join(suffixes), archive_type, compression
311+
suffix = suffixes[-1]
325312

326313
# check if the suffix is a known alias
327-
with contextlib.suppress(KeyError):
314+
if suffix in _FILE_TYPE_ALIASES:
328315
return (suffix, *_FILE_TYPE_ALIASES[suffix])
329316

330317
# check if the suffix is an archive type
331-
with contextlib.suppress(RuntimeError):
332-
_verify_archive_type(suffix)
318+
if suffix in _ARCHIVE_EXTRACTORS:
333319
return suffix, suffix, None
334320

335321
# check if the suffix is a compression
336-
with contextlib.suppress(RuntimeError):
337-
_verify_compression(suffix)
322+
if suffix in _COMPRESSED_FILE_OPENERS:
323+
# check for suffix hierarchy
324+
if len(suffixes) > 1:
325+
suffix2 = suffixes[-2]
326+
327+
# check if the suffix2 is an archive type
328+
if suffix2 in _ARCHIVE_EXTRACTORS:
329+
return suffix2 + suffix, suffix2, suffix
330+
338331
return suffix, None, suffix
339332

340-
raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")
333+
valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
334+
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")
341335

342336

343337
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:

0 commit comments

Comments
 (0)