Skip to content

Commit d627479

Browse files
committed
[PoC] merge mock data preparation and loading
1 parent 5db94a8 commit d627479

File tree

2 files changed

+52
-55
lines changed

2 files changed

+52
-55
lines changed

test/builtin_dataset_mocks.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pathlib
1111
import pickle
1212
import random
13+
import shutil
1314
import unittest.mock
1415
import warnings
1516
import xml.etree.ElementTree as ET
@@ -22,7 +23,6 @@
2223
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid
2324
from torch.nn.functional import one_hot
2425
from torch.testing import make_tensor as _make_tensor
25-
from torchvision._utils import sequence_to_str
2626
from torchvision.prototype import datasets
2727

2828
make_tensor = functools.partial(_make_tensor, device="cpu")
@@ -62,27 +62,47 @@ def _parse_mock_info(self, mock_info):
6262

6363
return mock_info
6464

65-
def prepare(self, config):
65+
def load(self, config):
6666
# `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in
6767
# test/test_prototype_builtin_datasets.py
6868
root = pathlib.Path(datasets.home()) / self.name
69-
root.mkdir(exist_ok=True)
69+
mock_data_folder = root / "__mock__"
70+
mock_data_folder.mkdir(parents=True)
7071

71-
mock_info = self._parse_mock_info(self.mock_data_fn(root, config))
72+
mock_info = self._parse_mock_info(self.mock_data_fn(mock_data_folder, config))
7273

73-
with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"):
74-
required_file_names = {
75-
resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
76-
}
77-
available_file_names = {path.name for path in root.glob("*")}
78-
missing_file_names = required_file_names - available_file_names
79-
if missing_file_names:
80-
raise pytest.UsageError(
81-
f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} "
82-
f"for {config}, but they were not created by the mock data function."
83-
)
74+
def mock_data_download(resource, root, **kwargs):
75+
src = mock_data_folder / resource.file_name
76+
if not src.exists():
77+
raise pytest.UsageError(
78+
f"Dataset '{self.name}' requires the file {resource.file_name} for {config}, "
79+
f"but it was not created by the mock data function."
80+
)
8481

85-
return mock_info
82+
dst = root / resource.file_name
83+
shutil.move(str(src), str(root))
84+
85+
return dst
86+
87+
with unittest.mock.patch(
88+
"torchvision.prototype.datasets.utils._resource.OnlineResource.download", new=mock_data_download
89+
):
90+
dataset = datasets.load(self.name, **config)
91+
92+
extra_files = list(mock_data_folder.glob("**/*"))
93+
if not extra_files:
94+
mock_data_folder.rmdir()
95+
else:
96+
pass
97+
# raise pytest.UsageError(
98+
# (
99+
# f"Dataset '{self.name}' created the following files for {config} in the mock data function, "
100+
# f"but they were not loaded:\n\n"
101+
# )
102+
# + "\n".join(str(file.relative_to(mock_data_folder)) for file in extra_files)
103+
# )
104+
105+
return dataset, mock_info
86106

87107

88108
def config_id(name, config):

test/test_prototype_builtin_datasets.py

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,14 @@ def test_info(self, name):
5656

5757
@parametrize_dataset_mocks(DATASET_MOCKS)
5858
def test_smoke(self, dataset_mock, config):
59-
dataset_mock.prepare(config)
60-
61-
dataset = datasets.load(dataset_mock.name, **config)
59+
dataset, _ = dataset_mock.load(config)
6260

6361
if not isinstance(dataset, datasets.utils.Dataset):
6462
raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
6563

6664
@parametrize_dataset_mocks(DATASET_MOCKS)
6765
def test_sample(self, dataset_mock, config):
68-
dataset_mock.prepare(config)
69-
70-
dataset = datasets.load(dataset_mock.name, **config)
66+
dataset, _ = dataset_mock.load(config)
7167

7268
try:
7369
sample = next(iter(dataset))
@@ -84,17 +80,13 @@ def test_sample(self, dataset_mock, config):
8480

8581
@parametrize_dataset_mocks(DATASET_MOCKS)
8682
def test_num_samples(self, dataset_mock, config):
87-
mock_info = dataset_mock.prepare(config)
88-
89-
dataset = datasets.load(dataset_mock.name, **config)
83+
dataset, mock_info = dataset_mock.load(config)
9084

9185
assert len(list(dataset)) == mock_info["num_samples"]
9286

9387
@parametrize_dataset_mocks(DATASET_MOCKS)
9488
def test_no_vanilla_tensors(self, dataset_mock, config):
95-
dataset_mock.prepare(config)
96-
97-
dataset = datasets.load(dataset_mock.name, **config)
89+
dataset, _ = dataset_mock.load(config)
9890

9991
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
10092
if vanilla_tensors:
@@ -105,24 +97,20 @@ def test_no_vanilla_tensors(self, dataset_mock, config):
10597

10698
@parametrize_dataset_mocks(DATASET_MOCKS)
10799
def test_transformable(self, dataset_mock, config):
108-
dataset_mock.prepare(config)
109-
110-
dataset = datasets.load(dataset_mock.name, **config)
100+
dataset, _ = dataset_mock.load(config)
111101

112102
next(iter(dataset.map(transforms.Identity())))
113103

114104
@pytest.mark.parametrize("only_datapipe", [False, True])
115105
@parametrize_dataset_mocks(DATASET_MOCKS)
116106
def test_traversable(self, dataset_mock, config, only_datapipe):
117-
dataset_mock.prepare(config)
118-
dataset = datasets.load(dataset_mock.name, **config)
107+
dataset, _ = dataset_mock.load(config)
119108

120109
traverse(dataset, only_datapipe=only_datapipe)
121110

122111
@parametrize_dataset_mocks(DATASET_MOCKS)
123112
def test_serializable(self, dataset_mock, config):
124-
dataset_mock.prepare(config)
125-
dataset = datasets.load(dataset_mock.name, **config)
113+
dataset, _ = dataset_mock.load(config)
126114

127115
pickle.dumps(dataset)
128116

@@ -135,8 +123,7 @@ def _collate_fn(self, batch):
135123
@pytest.mark.parametrize("num_workers", [0, 1])
136124
@parametrize_dataset_mocks(DATASET_MOCKS)
137125
def test_data_loader(self, dataset_mock, config, num_workers):
138-
dataset_mock.prepare(config)
139-
dataset = datasets.load(dataset_mock.name, **config)
126+
dataset, _ = dataset_mock.load(config)
140127

141128
dl = DataLoader(
142129
dataset,
@@ -153,17 +140,15 @@ def test_data_loader(self, dataset_mock, config, num_workers):
153140
@parametrize_dataset_mocks(DATASET_MOCKS)
154141
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
155142
def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
156-
157-
dataset_mock.prepare(config)
158-
dataset = datasets.load(dataset_mock.name, **config)
143+
dataset, _ = dataset_mock.load(config)
159144

160145
if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
161146
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
162147

163148
@parametrize_dataset_mocks(DATASET_MOCKS)
164149
def test_save_load(self, dataset_mock, config):
165-
dataset_mock.prepare(config)
166-
dataset = datasets.load(dataset_mock.name, **config)
150+
dataset, _ = dataset_mock.load(config)
151+
167152
sample = next(iter(dataset))
168153

169154
with io.BytesIO() as buffer:
@@ -173,8 +158,7 @@ def test_save_load(self, dataset_mock, config):
173158

174159
@parametrize_dataset_mocks(DATASET_MOCKS)
175160
def test_infinite_buffer_size(self, dataset_mock, config):
176-
dataset_mock.prepare(config)
177-
dataset = datasets.load(dataset_mock.name, **config)
161+
dataset, _ = dataset_mock.load(config)
178162

179163
for dp in extract_datapipes(dataset):
180164
if hasattr(dp, "buffer_size"):
@@ -184,18 +168,15 @@ def test_infinite_buffer_size(self, dataset_mock, config):
184168

185169
@parametrize_dataset_mocks(DATASET_MOCKS)
186170
def test_has_length(self, dataset_mock, config):
187-
dataset_mock.prepare(config)
188-
dataset = datasets.load(dataset_mock.name, **config)
171+
dataset, _ = dataset_mock.load(config)
189172

190173
assert len(dataset) > 0
191174

192175

193176
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
194177
class TestQMNIST:
195178
def test_extra_label(self, dataset_mock, config):
196-
dataset_mock.prepare(config)
197-
198-
dataset = datasets.load(dataset_mock.name, **config)
179+
dataset, _ = dataset_mock.load(config)
199180

200181
sample = next(iter(dataset))
201182
for key, type in (
@@ -218,9 +199,7 @@ def test_label_matches_path(self, dataset_mock, config):
218199
if config["split"] != "train":
219200
return
220201

221-
dataset_mock.prepare(config)
222-
223-
dataset = datasets.load(dataset_mock.name, **config)
202+
dataset, _ = dataset_mock.load(config)
224203

225204
for sample in dataset:
226205
label_from_path = int(Path(sample["path"]).parent.name)
@@ -230,9 +209,7 @@ def test_label_matches_path(self, dataset_mock, config):
230209
@parametrize_dataset_mocks(DATASET_MOCKS["usps"])
231210
class TestUSPS:
232211
def test_sample_content(self, dataset_mock, config):
233-
dataset_mock.prepare(config)
234-
235-
dataset = datasets.load(dataset_mock.name, **config)
212+
dataset, _ = dataset_mock.load(config)
236213

237214
for sample in dataset:
238215
assert "image" in sample

0 commit comments

Comments
 (0)