|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | +import expecttest |
| 3 | +import os |
| 4 | +import unittest |
| 5 | +import warnings |
| 6 | + |
| 7 | +from torchdata.datapipes.iter import ( |
| 8 | + FileLister, |
| 9 | + IterableWrapper, |
| 10 | + FSSpecFileLister, |
| 11 | + FSSpecFileOpener, |
| 12 | + FSSpecSaver, |
| 13 | +) |
| 14 | + |
| 15 | +from _utils._common_utils_for_test import ( |
| 16 | + create_temp_dir, |
| 17 | + create_temp_files, |
| 18 | + reset_after_n_next_calls, |
| 19 | +) |
| 20 | + |
| 21 | +try: |
| 22 | + import fsspec |
| 23 | + |
| 24 | + HAS_FSSPEC = True |
| 25 | +except ImportError: |
| 26 | + HAS_FSSPEC = False |
| 27 | +skipIfNoFSSpec = unittest.skipIf(not HAS_FSSPEC, "no fsspec") |
| 28 | + |
| 29 | + |
| 30 | +class TestDataPipeFSSpec(expecttest.TestCase): |
| 31 | + def setUp(self): |
| 32 | + self.temp_dir = create_temp_dir() |
| 33 | + self.temp_files = create_temp_files(self.temp_dir) |
| 34 | + self.temp_sub_dir = create_temp_dir(self.temp_dir.name) |
| 35 | + self.temp_sub_files = create_temp_files(self.temp_sub_dir, 4, False) |
| 36 | + |
| 37 | + def tearDown(self): |
| 38 | + try: |
| 39 | + self.temp_sub_dir.cleanup() |
| 40 | + self.temp_dir.cleanup() |
| 41 | + except Exception as e: |
| 42 | + warnings.warn( |
| 43 | + f"TestDataPipeLocalIO was not able to cleanup temp dir due to {e}" |
| 44 | + ) |
| 45 | + |
| 46 | + def _write_text_files(self): |
| 47 | + def filepath_fn(name: str) -> str: |
| 48 | + return os.path.join(self.temp_dir.name, os.path.basename(name)) |
| 49 | + |
| 50 | + name_to_data = {"1.text": b"DATA", "2.text": b"DATA", "3.text": b"DATA"} |
| 51 | + source_dp = IterableWrapper(sorted(name_to_data.items())) |
| 52 | + saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb") |
| 53 | + list(saver_dp) |
| 54 | + |
| 55 | + @skipIfNoFSSpec |
| 56 | + def test_fsspec_file_lister_iterdatapipe(self): |
| 57 | + datapipe = FSSpecFileLister(root="file://" + self.temp_sub_dir.name) |
| 58 | + |
| 59 | + # check all file paths within sub_folder are listed |
| 60 | + for path in datapipe: |
| 61 | + self.assertIn( |
| 62 | + path.split("://")[1], |
| 63 | + { |
| 64 | + fsspec.implementations.local.make_path_posix(file) |
| 65 | + for file in self.temp_sub_files |
| 66 | + }, |
| 67 | + ) |
| 68 | + |
| 69 | + @skipIfNoFSSpec |
| 70 | + def test_fsspec_file_loader_iterdatapipe(self): |
| 71 | + datapipe1 = FSSpecFileLister(root="file://" + self.temp_sub_dir.name) |
| 72 | + datapipe2 = FSSpecFileOpener(datapipe1) |
| 73 | + |
| 74 | + # check contents of file match |
| 75 | + for _, f in datapipe2: |
| 76 | + self.assertEqual(f.read(), "0123456789abcdef") |
| 77 | + |
| 78 | + # Reset Test: Ensure the resulting streams are still readable after the DataPipe is reset/exhausted |
| 79 | + self._write_text_files() |
| 80 | + lister_dp = FileLister(self.temp_dir.name, "*.text") |
| 81 | + fsspec_file_loader_dp = FSSpecFileOpener(lister_dp, mode="rb") |
| 82 | + |
| 83 | + n_elements_before_reset = 2 |
| 84 | + res_before_reset, res_after_reset = reset_after_n_next_calls( |
| 85 | + fsspec_file_loader_dp, n_elements_before_reset |
| 86 | + ) |
| 87 | + self.assertEqual(2, len(res_before_reset)) |
| 88 | + self.assertEqual(3, len(res_after_reset)) |
| 89 | + for _name, stream in res_before_reset: |
| 90 | + self.assertEqual(b"DATA", stream.read()) |
| 91 | + for _name, stream in res_after_reset: |
| 92 | + self.assertEqual(b"DATA", stream.read()) |
| 93 | + |
| 94 | + @skipIfNoFSSpec |
| 95 | + def test_fsspec_saver_iterdatapipe(self): |
| 96 | + def filepath_fn(name: str) -> str: |
| 97 | + return "file://" + os.path.join(self.temp_dir.name, os.path.basename(name)) |
| 98 | + |
| 99 | + # Functional Test: Saving some data |
| 100 | + name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"} |
| 101 | + source_dp = IterableWrapper(sorted(name_to_data.items())) |
| 102 | + saver_dp = source_dp.save_by_fsspec(filepath_fn=filepath_fn, mode="wb") |
| 103 | + res_file_paths = list(saver_dp) |
| 104 | + expected_paths = [filepath_fn(name) for name in name_to_data.keys()] |
| 105 | + self.assertEqual(expected_paths, res_file_paths) |
| 106 | + for name in name_to_data.keys(): |
| 107 | + p = filepath_fn(name).split("://")[1] |
| 108 | + with open(p, "r") as f: |
| 109 | + self.assertEqual(name_to_data[name], f.read().encode()) |
| 110 | + |
| 111 | + # Reset Test: |
| 112 | + saver_dp = FSSpecSaver(source_dp, filepath_fn=filepath_fn, mode="wb") |
| 113 | + n_elements_before_reset = 2 |
| 114 | + res_before_reset, res_after_reset = reset_after_n_next_calls( |
| 115 | + saver_dp, n_elements_before_reset |
| 116 | + ) |
| 117 | + self.assertEqual([filepath_fn("1.txt"), filepath_fn("2.txt")], res_before_reset) |
| 118 | + self.assertEqual(expected_paths, res_after_reset) |
| 119 | + for name in name_to_data.keys(): |
| 120 | + p = filepath_fn(name).split("://")[1] |
| 121 | + with open(p, "r") as f: |
| 122 | + self.assertEqual(name_to_data[name], f.read().encode()) |
| 123 | + |
| 124 | + # __len__ Test: returns the length of source DataPipe |
| 125 | + self.assertEqual(3, len(saver_dp)) |
| 126 | + |
| 127 | + @skipIfNoFSSpec |
| 128 | + def test_fsspec_memory_list(self): |
| 129 | + fs = fsspec.filesystem("memory") |
| 130 | + fs.mkdir("foo") |
| 131 | + fs.touch("foo/bar1") |
| 132 | + fs.touch("foo/bar2") |
| 133 | + |
| 134 | + datapipe = FSSpecFileLister(root="memory://foo") |
| 135 | + self.assertEqual(set(datapipe), {"memory:///foo/bar1", "memory:///foo/bar2"}) |
| 136 | + |
| 137 | + datapipe = FSSpecFileLister(root="memory://foo/bar1") |
| 138 | + self.assertEqual(set(datapipe), {"memory://foo/bar1"}) |
| 139 | + |
| 140 | + @skipIfNoFSSpec |
| 141 | + def test_fsspec_memory_load(self): |
| 142 | + fs = fsspec.filesystem("memory") |
| 143 | + with fs.open("file", "w") as f: |
| 144 | + f.write("hello") |
| 145 | + with fs.open("file2", "w") as f: |
| 146 | + f.write("hello2") |
| 147 | + |
| 148 | + files = ["memory://file", "memory://file2"] |
| 149 | + datapipe = FSSpecFileOpener(files) |
| 150 | + self.assertEqual([f.read() for _, f in datapipe], ["hello", "hello2"]) |
| 151 | + |
| 152 | + @skipIfNoFSSpec |
| 153 | + def test_fsspec_memory_save(self): |
| 154 | + def filepath_fn(name: str) -> str: |
| 155 | + return "memory://" + name |
| 156 | + |
| 157 | + name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2"} |
| 158 | + source_dp = IterableWrapper(sorted(name_to_data.items())) |
| 159 | + saver_dp = FSSpecSaver(source_dp, filepath_fn=filepath_fn, mode="wb") |
| 160 | + |
| 161 | + self.assertEqual(set(saver_dp), {"memory://1.txt", "memory://2.txt"}) |
| 162 | + |
| 163 | + |
| 164 | +if __name__ == "__main__": |
| 165 | + unittest.main() |
0 commit comments