
Commit b1d2199

NicolasHug authored and facebook-github-bot committed
[fbsync] rely on patched datasets home rather than passing it around (#5998)
Summary:

* rely on patched datasets home rather than passing it around
* add comment

Reviewed By: datumbox

Differential Revision: D36413359

fbshipit-source-id: cccff560e6e29da8401fab866a34ebe0fdde6b9a
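For readers skimming the diff below, here is a minimal sketch of the pattern the commit moves to. It is illustrative only: DatasetMock is a simplified stand-in for the real mock class in test/builtin_dataset_mocks.py, not the actual torchvision code. The idea is that an autouse fixture patches torchvision.prototype.datasets.home() to a temporary directory, and the mock's prepare() resolves that directory itself instead of having it passed in:

    import pathlib

    import pytest

    from torchvision.prototype import datasets


    @pytest.fixture(autouse=True)
    def test_home(mocker, tmp_path):
        # Every test gets a patched home() pointing at a fresh temporary directory,
        # without having to request this fixture explicitly.
        mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
        mocker.patch("torchvision.prototype.datasets.home", return_value=str(tmp_path))
        yield tmp_path


    class DatasetMock:
        # Simplified stand-in for the real dataset mock class.
        def __init__(self, name, mock_data_fn):
            self.name = name
            self.mock_data_fn = mock_data_fn

        def prepare(self, config):
            # datasets.home() is patched by the fixture above, so the mock data
            # lands in the per-test temporary directory.
            root = pathlib.Path(datasets.home()) / self.name
            root.mkdir(exist_ok=True)
            return self.mock_data_fn(root, config)

Because the fixture is autouse, individual tests no longer need to request test_home or pass a home directory to prepare(), which is what most of the diff below does.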
1 parent e70d4e3 commit b1d2199

File tree

test/builtin_dataset_mocks.py
test/test_prototype_builtin_datasets.py

2 files changed: +36 -33 lines changed

test/builtin_dataset_mocks.py

Lines changed: 4 additions & 2 deletions
@@ -62,8 +62,10 @@ def _parse_mock_info(self, mock_info):

         return mock_info

-    def prepare(self, home, config):
-        root = home / self.name
+    def prepare(self, config):
+        # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in
+        # test/test_prototype_builtin_datasets.py
+        root = pathlib.Path(datasets.home()) / self.name
         root.mkdir(exist_ok=True)

         mock_info = self._parse_mock_info(self.mock_data_fn(root, config))

test/test_prototype_builtin_datasets.py

Lines changed: 32 additions & 31 deletions
@@ -25,9 +25,10 @@ def extract_datapipes(dp):
     return get_all_graph_pipes(traverse(dp, only_datapipe=True))


-@pytest.fixture
+@pytest.fixture(autouse=True)
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
+    mocker.patch("torchvision.prototype.datasets.home", return_value=str(tmp_path))
     yield tmp_path

@@ -54,17 +55,17 @@ def test_info(self, name):
             raise AssertionError("Info should be a dictionary with string keys.")

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_smoke(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_smoke(self, dataset_mock, config):
+        dataset_mock.prepare(config)

         dataset = datasets.load(dataset_mock.name, **config)

         if not isinstance(dataset, datasets.utils.Dataset):
             raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_sample(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_sample(self, dataset_mock, config):
+        dataset_mock.prepare(config)

         dataset = datasets.load(dataset_mock.name, **config)

@@ -82,16 +83,16 @@ def test_sample(self, test_home, dataset_mock, config):
             raise AssertionError("Sample dictionary is empty.")

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_num_samples(self, test_home, dataset_mock, config):
-        mock_info = dataset_mock.prepare(test_home, config)
+    def test_num_samples(self, dataset_mock, config):
+        mock_info = dataset_mock.prepare(config)

         dataset = datasets.load(dataset_mock.name, **config)

         assert len(list(dataset)) == mock_info["num_samples"]

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_no_vanilla_tensors(self, dataset_mock, config):
+        dataset_mock.prepare(config)

         dataset = datasets.load(dataset_mock.name, **config)

@@ -103,24 +104,24 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
             )

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_transformable(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_transformable(self, dataset_mock, config):
+        dataset_mock.prepare(config)

         dataset = datasets.load(dataset_mock.name, **config)

         next(iter(dataset.map(transforms.Identity())))

     @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
-        dataset_mock.prepare(test_home, config)
+    def test_traversable(self, dataset_mock, config, only_datapipe):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)

         traverse(dataset, only_datapipe=only_datapipe)

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_serializable(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_serializable(self, dataset_mock, config):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)

         pickle.dumps(dataset)
@@ -133,8 +134,8 @@ def _collate_fn(self, batch):

     @pytest.mark.parametrize("num_workers", [0, 1])
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_data_loader(self, test_home, dataset_mock, config, num_workers):
-        dataset_mock.prepare(test_home, config)
+    def test_data_loader(self, dataset_mock, config, num_workers):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)

         dl = DataLoader(
@@ -151,17 +152,17 @@ def test_data_loader(self, test_home, dataset_mock, config, num_workers):
     # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
-    def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
+    def test_has_annotations(self, dataset_mock, config, annotation_dp_type):

-        dataset_mock.prepare(test_home, config)
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)

         if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
             raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_save_load(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_save_load(self, dataset_mock, config):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)
         sample = next(iter(dataset))

@@ -171,8 +172,8 @@ def test_save_load(self, test_home, dataset_mock, config):
         assert_samples_equal(torch.load(buffer), sample)

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_infinite_buffer_size(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_infinite_buffer_size(self, dataset_mock, config):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)

         for dp in extract_datapipes(dataset):
@@ -182,17 +183,17 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
                 assert dp.buffer_size == INFINITE_BUFFER_SIZE

     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_has_length(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_has_length(self, dataset_mock, config):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)

         assert len(dataset) > 0


 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
-    def test_extra_label(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_extra_label(self, dataset_mock, config):
+        dataset_mock.prepare(config)

         dataset = datasets.load(dataset_mock.name, **config)

@@ -211,13 +212,13 @@ def test_extra_label(self, test_home, dataset_mock, config):

 @parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
 class TestGTSRB:
-    def test_label_matches_path(self, test_home, dataset_mock, config):
+    def test_label_matches_path(self, dataset_mock, config):
         # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
         # This test makes sure that they're both the same
         if config["split"] != "train":
             return

-        dataset_mock.prepare(test_home, config)
+        dataset_mock.prepare(config)

         dataset = datasets.load(dataset_mock.name, **config)

@@ -228,8 +229,8 @@ def test_label_matches_path(self, test_home, dataset_mock, config):

 @parametrize_dataset_mocks(DATASET_MOCKS["usps"])
 class TestUSPS:
-    def test_sample_content(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_sample_content(self, dataset_mock, config):
+        dataset_mock.prepare(config)

         dataset = datasets.load(dataset_mock.name, **config)
