diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index 768d286e890..cea0f297be5 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -62,8 +62,10 @@ def _parse_mock_info(self, mock_info):
 
         return mock_info
 
-    def prepare(self, home, config):
-        root = home / self.name
+    def prepare(self, config):
+        # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in
+        # test/test_prototype_builtin_datasets.py
+        root = pathlib.Path(datasets.home()) / self.name
         root.mkdir(exist_ok=True)
 
         mock_info = self._parse_mock_info(self.mock_data_fn(root, config))
diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index 842a0048afc..23190b25ddc 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -25,9 +25,10 @@ def extract_datapipes(dp):
     return get_all_graph_pipes(traverse(dp, only_datapipe=True))
 
 
-@pytest.fixture
+@pytest.fixture(autouse=True)
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
+    mocker.patch("torchvision.prototype.datasets.home", return_value=str(tmp_path))
     yield tmp_path
 
 
@@ -54,8 +55,8 @@ def test_info(self, name):
             raise AssertionError("Info should be a dictionary with string keys.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_smoke(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_smoke(self, dataset_mock, config):
+        dataset_mock.prepare(config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
@@ -63,8 +64,8 @@ def test_smoke(self, test_home, dataset_mock, config):
             raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_sample(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_sample(self, dataset_mock, config):
+        dataset_mock.prepare(config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
@@ -82,16 +83,16 @@ def test_sample(self, test_home, dataset_mock, config):
             raise AssertionError("Sample dictionary is empty.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_num_samples(self, test_home, dataset_mock, config):
-        mock_info = dataset_mock.prepare(test_home, config)
+    def test_num_samples(self, dataset_mock, config):
+        mock_info = dataset_mock.prepare(config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
         assert len(list(dataset)) == mock_info["num_samples"]
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_no_vanilla_tensors(self, dataset_mock, config):
+        dataset_mock.prepare(config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
@@ -103,8 +104,8 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
             )
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_transformable(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_transformable(self, dataset_mock, config):
+        dataset_mock.prepare(config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
@@ -112,15 +113,15 @@ def test_transformable(self, test_home, dataset_mock, config):
 
     @pytest.mark.parametrize("only_datapipe", [False, True])
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_traversable(self, test_home, dataset_mock, config, only_datapipe):
-        dataset_mock.prepare(test_home, config)
+    def test_traversable(self, dataset_mock, config, only_datapipe):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         traverse(dataset, only_datapipe=only_datapipe)
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_serializable(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_serializable(self, dataset_mock, config):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         pickle.dumps(dataset)
@@ -133,8 +134,8 @@ def _collate_fn(self, batch):
 
     @pytest.mark.parametrize("num_workers", [0, 1])
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_data_loader(self, test_home, dataset_mock, config, num_workers):
-        dataset_mock.prepare(test_home, config)
+    def test_data_loader(self, dataset_mock, config, num_workers):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         dl = DataLoader(
@@ -151,17 +152,17 @@ def test_data_loader(self, test_home, dataset_mock, config, num_workers):
     # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
-    def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
+    def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
 
-        dataset_mock.prepare(test_home, config)
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         if not any(isinstance(dp, annotation_dp_type) for dp in extract_datapipes(dataset)):
             raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_save_load(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_save_load(self, dataset_mock, config):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)
         sample = next(iter(dataset))
 
@@ -171,8 +172,8 @@ def test_save_load(self, test_home, dataset_mock, config):
         assert_samples_equal(torch.load(buffer), sample)
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_infinite_buffer_size(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_infinite_buffer_size(self, dataset_mock, config):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         for dp in extract_datapipes(dataset):
@@ -182,8 +183,8 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
                 assert dp.buffer_size == INFINITE_BUFFER_SIZE
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_has_length(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_has_length(self, dataset_mock, config):
+        dataset_mock.prepare(config)
         dataset = datasets.load(dataset_mock.name, **config)
 
         assert len(dataset) > 0
@@ -191,8 +192,8 @@ def test_has_length(self, test_home, dataset_mock, config):
 
 @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
-    def test_extra_label(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_extra_label(self, dataset_mock, config):
+        dataset_mock.prepare(config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
@@ -211,13 +212,13 @@ def test_extra_label(self, test_home, dataset_mock, config):
 
 @parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"])
 class TestGTSRB:
-    def test_label_matches_path(self, test_home, dataset_mock, config):
+    def test_label_matches_path(self, dataset_mock, config):
         # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path.
         # This test makes sure that they're both the same
         if config["split"] != "train":
             return
 
-        dataset_mock.prepare(test_home, config)
+        dataset_mock.prepare(config)
 
         dataset = datasets.load(dataset_mock.name, **config)
 
@@ -228,8 +229,8 @@ def test_label_matches_path(self, test_home, dataset_mock, config):
 
 @parametrize_dataset_mocks(DATASET_MOCKS["usps"])
 class TestUSPS:
-    def test_sample_content(self, test_home, dataset_mock, config):
-        dataset_mock.prepare(test_home, config)
+    def test_sample_content(self, dataset_mock, config):
+        dataset_mock.prepare(config)
 
         dataset = datasets.load(dataset_mock.name, **config)
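
Note on the pattern: the change relies on pytest autouse fixtures combined with pytest-mock's `mocker.patch`. Because `test_home` is now applied to every test automatically, `prepare()` can derive its root from the patched `datasets.home()` instead of taking an explicit `home` argument. The snippet below is a minimal, self-contained sketch of that pattern; `home()` and `prepare_mock_data()` are local stand-ins for illustration only, not torchvision APIs, and pytest-mock is assumed to be installed.

# test_autouse_home_sketch.py -- sketch of the autouse-fixture pattern used in this PR.
import pathlib

import pytest


def home() -> str:
    # Stand-in for the library's "where do datasets live" helper.
    return "/unused/default/home"


def prepare_mock_data(name: str) -> pathlib.Path:
    # Mirrors the new prepare(config) signature: the root is derived from the
    # (patched) home() instead of being passed in explicitly.
    root = pathlib.Path(home()) / name
    root.mkdir(parents=True, exist_ok=True)
    return root


@pytest.fixture(autouse=True)
def fake_home(mocker, tmp_path):
    # autouse=True applies the fixture to every test in scope, so tests no longer
    # need to request it as an argument; mocker.patch is undone automatically.
    mocker.patch(f"{__name__}.home", return_value=str(tmp_path))
    yield tmp_path


def test_prepare_writes_into_tmp_home(tmp_path):
    root = prepare_mock_data("qmnist")
    assert root == tmp_path / "qmnist"  # proves the patched home() was used

The same mechanism is why the diff patches both `torchvision.prototype.datasets._api.home` and `torchvision.prototype.datasets.home`: `mock.patch` replaces the attribute in the namespace where it is looked up, so each access path needs its own patch.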