Skip to content

Commit 53e1330

Browse files
YosuaMichaelfacebook-github-bot
authored andcommitted
[fbsync] simplify OnlineResource.load (#5990)
Summary: * simplify OnlineResource.load * [PoC] merge mock data preparation and loading * Revert "cache mock data based on config" This reverts commit 5ed6eedef74865e0baa746a375d5ec1f0ab1bde7. * Revert "[PoC] merge mock data preparation and loading" This reverts commit d627479. * remove preprocess returning a new path in favor of querying twice * address test comments * clarify comment * mypy * use builtin decompress utility Reviewed By: NicolasHug Differential Revision: D36760923 fbshipit-source-id: 1d3d30be96c3226fc181c4654208b2d3c6fdf7cb
1 parent 38207bf commit 53e1330

File tree

2 files changed

+217
-32
lines changed

2 files changed

+217
-32
lines changed

test/test_prototype_datasets_utils.py

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
import gzip
2+
import pathlib
13
import sys
24

35
import numpy as np
46
import pytest
57
import torch
6-
from datasets_utils import make_fake_flo_file
8+
from datasets_utils import make_fake_flo_file, make_tar
9+
from torchdata.datapipes.iter import FileOpener, TarArchiveLoader
710
from torchvision.datasets._optical_flow import _read_flo as read_flo_ref
8-
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset
11+
from torchvision.datasets.utils import _decompress
12+
from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset, OnlineResource
913
from torchvision.prototype.datasets.utils._internal import read_flo, fromfile
1014

1115

@@ -48,6 +52,183 @@ def test_read_flo(tmpdir):
4852
torch.testing.assert_close(actual, expected)
4953

5054

55+
class TestOnlineResource:
56+
class DummyResource(OnlineResource):
57+
def __init__(self, download_fn=None, **kwargs):
58+
super().__init__(**kwargs)
59+
self._download_fn = download_fn
60+
61+
def _download(self, root):
62+
if self._download_fn is None:
63+
raise pytest.UsageError(
64+
"`_download()` was called, but `DummyResource(...)` was constructed without `download_fn`."
65+
)
66+
67+
return self._download_fn(self, root)
68+
69+
def _make_file(self, root, *, content, name="file.txt"):
70+
file = root / name
71+
with open(file, "w") as fh:
72+
fh.write(content)
73+
74+
return file
75+
76+
def _make_folder(self, root, *, name="folder"):
77+
folder = root / name
78+
subfolder = folder / "subfolder"
79+
subfolder.mkdir(parents=True)
80+
81+
files = {}
82+
for idx, root in enumerate([folder, folder, subfolder]):
83+
content = f"sentinel{idx}"
84+
file = self._make_file(root, name=f"file{idx}.txt", content=content)
85+
files[str(file)] = content
86+
87+
return folder, files
88+
89+
def _make_tar(self, root, *, name="archive.tar", remove=True):
90+
folder, files = self._make_folder(root, name=name.split(".")[0])
91+
archive = make_tar(root, name, folder, remove=remove)
92+
files = {str(archive / pathlib.Path(file).relative_to(root)): content for file, content in files.items()}
93+
return archive, files
94+
95+
def test_load_file(self, tmp_path):
96+
content = "sentinel"
97+
file = self._make_file(tmp_path, content=content)
98+
99+
resource = self.DummyResource(file_name=file.name)
100+
101+
dp = resource.load(tmp_path)
102+
assert isinstance(dp, FileOpener)
103+
104+
data = list(dp)
105+
assert len(data) == 1
106+
107+
path, buffer = data[0]
108+
assert path == str(file)
109+
assert buffer.read().decode() == content
110+
111+
def test_load_folder(self, tmp_path):
112+
folder, files = self._make_folder(tmp_path)
113+
114+
resource = self.DummyResource(file_name=folder.name)
115+
116+
dp = resource.load(tmp_path)
117+
assert isinstance(dp, FileOpener)
118+
assert {path: buffer.read().decode() for path, buffer in dp} == files
119+
120+
def test_load_archive(self, tmp_path):
121+
archive, files = self._make_tar(tmp_path)
122+
123+
resource = self.DummyResource(file_name=archive.name)
124+
125+
dp = resource.load(tmp_path)
126+
assert isinstance(dp, TarArchiveLoader)
127+
assert {path: buffer.read().decode() for path, buffer in dp} == files
128+
129+
def test_priority_decompressed_gt_raw(self, tmp_path):
130+
# We don't need to actually compress here. Adding the suffix is sufficient
131+
self._make_file(tmp_path, content="raw_sentinel", name="file.txt.gz")
132+
file = self._make_file(tmp_path, content="decompressed_sentinel", name="file.txt")
133+
134+
resource = self.DummyResource(file_name=file.name)
135+
136+
dp = resource.load(tmp_path)
137+
path, buffer = next(iter(dp))
138+
139+
assert path == str(file)
140+
assert buffer.read().decode() == "decompressed_sentinel"
141+
142+
def test_priority_extracted_gt_decompressed(self, tmp_path):
143+
archive, _ = self._make_tar(tmp_path, remove=False)
144+
145+
resource = self.DummyResource(file_name=archive.name)
146+
147+
dp = resource.load(tmp_path)
148+
# If the archive had been selected, this would be a `TarArchiveReader`
149+
assert isinstance(dp, FileOpener)
150+
151+
def test_download(self, tmp_path):
152+
download_fn_was_called = False
153+
154+
def download_fn(resource, root):
155+
nonlocal download_fn_was_called
156+
download_fn_was_called = True
157+
158+
return self._make_file(root, content="_", name=resource.file_name)
159+
160+
resource = self.DummyResource(
161+
file_name="file.txt",
162+
download_fn=download_fn,
163+
)
164+
165+
resource.load(tmp_path)
166+
167+
assert download_fn_was_called, "`download_fn()` was never called"
168+
169+
# This tests the `"decompress"` literal as well as a custom callable
170+
@pytest.mark.parametrize(
171+
"preprocess",
172+
[
173+
"decompress",
174+
lambda path: _decompress(str(path), remove_finished=True),
175+
],
176+
)
177+
def test_preprocess_decompress(self, tmp_path, preprocess):
178+
file_name = "file.txt.gz"
179+
content = "sentinel"
180+
181+
def download_fn(resource, root):
182+
file = root / resource.file_name
183+
with gzip.open(file, "wb") as fh:
184+
fh.write(content.encode())
185+
return file
186+
187+
resource = self.DummyResource(file_name=file_name, preprocess=preprocess, download_fn=download_fn)
188+
189+
dp = resource.load(tmp_path)
190+
data = list(dp)
191+
assert len(data) == 1
192+
193+
path, buffer = data[0]
194+
assert path == str(tmp_path / file_name).replace(".gz", "")
195+
assert buffer.read().decode() == content
196+
197+
def test_preprocess_extract(self, tmp_path):
198+
files = None
199+
200+
def download_fn(resource, root):
201+
nonlocal files
202+
archive, files = self._make_tar(root, name=resource.file_name)
203+
return archive
204+
205+
resource = self.DummyResource(file_name="folder.tar", preprocess="extract", download_fn=download_fn)
206+
207+
dp = resource.load(tmp_path)
208+
assert files is not None, "`download_fn()` was never called"
209+
assert isinstance(dp, FileOpener)
210+
211+
actual = {path: buffer.read().decode() for path, buffer in dp}
212+
expected = {
213+
path.replace(resource.file_name, resource.file_name.split(".")[0]): content
214+
for path, content in files.items()
215+
}
216+
assert actual == expected
217+
218+
def test_preprocess_only_after_download(self, tmp_path):
219+
file = self._make_file(tmp_path, content="_")
220+
221+
def preprocess(path):
222+
raise AssertionError("`preprocess` was called although the file was already present.")
223+
224+
resource = self.DummyResource(
225+
file_name=file.name,
226+
preprocess=preprocess,
227+
)
228+
229+
resource.load(tmp_path)
230+
231+
51232
class TestHttpResource:
52233
def test_resolve_to_http(self, mocker):
53234
file_name = "data.tar"

torchvision/prototype/datasets/utils/_resource.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import hashlib
33
import itertools
44
import pathlib
5-
from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
5+
from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn, Set
66
from urllib.parse import urlparse
77

88
from torchdata.datapipes.iter import (
@@ -32,7 +32,7 @@ def __init__(
3232
*,
3333
file_name: str,
3434
sha256: Optional[str] = None,
35-
preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], pathlib.Path]]] = None,
35+
preprocess: Optional[Union[Literal["decompress", "extract"], Callable[[pathlib.Path], None]]] = None,
3636
) -> None:
3737
self.file_name = file_name
3838
self.sha256 = sha256
@@ -50,14 +50,12 @@ def __init__(
5050
self._preprocess = preprocess
5151

5252
@staticmethod
53-
def _extract(file: pathlib.Path) -> pathlib.Path:
54-
return pathlib.Path(
55-
extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
56-
)
53+
def _extract(file: pathlib.Path) -> None:
54+
extract_archive(str(file), to_path=str(file).replace("".join(file.suffixes), ""), remove_finished=False)
5755

5856
@staticmethod
59-
def _decompress(file: pathlib.Path) -> pathlib.Path:
60-
return pathlib.Path(_decompress(str(file), remove_finished=True))
57+
def _decompress(file: pathlib.Path) -> None:
58+
_decompress(str(file), remove_finished=True)
6159

6260
def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
6361
if path.is_dir():
@@ -91,32 +89,38 @@ def load(
9189
) -> IterDataPipe[Tuple[str, IO]]:
9290
root = pathlib.Path(root)
9391
path = root / self.file_name
92+
9493
# Instead of the raw file, there might also be files with fewer suffixes after decompression or directories
95-
# with no suffixes at all.
94+
# with no suffixes at all. `pathlib.Path().stem` will only give us the name with the last suffix removed, which
95+
# is not sufficient for files with multiple suffixes, e.g. foo.tar.gz.
9696
stem = path.name.replace("".join(path.suffixes), "")
9797

98-
# In a first step, we check for a folder with the same stem as the raw file. If it exists, we use it since
99-
# extracted files give the best I/O performance. Note that OnlineResource._extract() makes sure that an archive
100-
# is always extracted in a folder with the corresponding file name.
101-
folder_candidate = path.parent / stem
102-
if folder_candidate.exists() and folder_candidate.is_dir():
103-
return self._loader(folder_candidate)
104-
105-
# If there is no folder, we look for all files that share the same stem as the raw file, but might have a
106-
# different suffix.
107-
file_candidates = {file for file in path.parent.glob(stem + ".*")}
108-
# If we don't find anything, we download the raw file.
109-
if not file_candidates:
110-
file_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
111-
# If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
112-
if file_candidates == {path}:
98+
def find_candidates() -> Set[pathlib.Path]:
99+
# Although it looks like we could glob for f"{stem}*" to find the file candidates as well as the folder
100+
# candidate simultaneously, that would also pick up other files that share the same prefix. For example, the
101+
# test split of the stanford-cars dataset uses the files
102+
# - cars_test.tgz
103+
# - cars_test_annos_withlabels.mat
104+
# Globbing for `"cars_test*"` picks up both.
105+
candidates = {file for file in path.parent.glob(f"{stem}.*")}
106+
folder_candidate = path.parent / stem
107+
if folder_candidate.exists():
108+
candidates.add(folder_candidate)
109+
110+
return candidates
111+
112+
candidates = find_candidates()
113+
114+
if not candidates:
115+
self.download(root, skip_integrity_check=skip_integrity_check)
113116
if self._preprocess is not None:
114-
path = self._preprocess(path)
115-
# Otherwise, we use the path with the fewest suffixes. This gives us the decompressed > raw priority that we
116-
# want for the best I/O performance.
117-
else:
118-
path = min(file_candidates, key=lambda path: len(path.suffixes))
119-
return self._loader(path)
117+
self._preprocess(path)
118+
candidates = find_candidates()
119+
120+
# We use the path with the fewest suffixes. This gives us the
121+
# extracted > decompressed > raw
122+
# priority that we want for the best I/O performance.
123+
return self._loader(min(candidates, key=lambda candidate: len(candidate.suffixes)))
120124

121125
@abc.abstractmethod
122126
def _download(self, root: pathlib.Path) -> None:

0 commit comments

Comments
 (0)