
Commit df7db25

Yosua Michael Maranatha authored and facebook-github-bot committed
[fbsync] close streams in prototype datasets (#6647)
Summary:
* close streams in prototype datasets
* refactor prototype SBD to avoid closing demux streams at construction time
* mypy

Reviewed By: NicolasHug

Differential Revision: D40427477

fbshipit-source-id: 854554f283ff281f8c9eb0e2786644116a4b4dd8
1 parent e0a77a9 commit df7db25

File tree

11 files changed: +135 lines, -67 lines
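
The commit applies one recurring pattern across the builtin prototype datasets: a datapipe that receives a (path, stream) pair closes the stream itself once the stream has been fully read, instead of leaving the handle to the garbage collector. A minimal sketch of the idea, using a toy datapipe that is not part of this commit (names are illustrative only):

from torchdata.datapipes.iter import IterDataPipe


class LineYielderDataPipe(IterDataPipe):
    # Toy datapipe: takes (path, binary stream) pairs and yields decoded lines.
    def __init__(self, source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for _, file in self.source_datapipe:
            for line in file:
                yield line.decode()
            # Close the handle as soon as it is exhausted, mirroring the
            # file.close() / handle.close() calls added in celeba.py, mnist.py,
            # and pcam.py below.
            file.close()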


test/builtin_dataset_mocks.py

Lines changed: 15 additions & 13 deletions
@@ -661,15 +661,15 @@ class SBDMockData:
     _NUM_CATEGORIES = 20
 
     @classmethod
-    def _make_split_files(cls, root_map):
-        ids_map = {
-            split: [f"2008_{idx:06d}" for idx in idcs]
-            for split, idcs in (
-                ("train", [0, 1, 2]),
-                ("train_noval", [0, 2]),
-                ("val", [3]),
-            )
-        }
+    def _make_split_files(cls, root_map, *, split):
+        splits_and_idcs = [
+            ("train", [0, 1, 2]),
+            ("val", [3]),
+        ]
+        if split == "train_noval":
+            splits_and_idcs.append(("train_noval", [0, 2]))
+
+        ids_map = {split: [f"2008_{idx:06d}" for idx in idcs] for split, idcs in splits_and_idcs}
 
         for split, ids in ids_map.items():
             with open(root_map[split] / f"{split}.txt", "w") as fh:
@@ -710,25 +710,27 @@ def _make_segmentation(cls, size):
         return torch.randint(0, cls._NUM_CATEGORIES + 1, size=size, dtype=torch.uint8).numpy()
 
     @classmethod
-    def generate(cls, root):
+    def generate(cls, root, *, split):
         archive_folder = root / "benchmark_RELEASE"
         dataset_folder = archive_folder / "dataset"
         dataset_folder.mkdir(parents=True, exist_ok=True)
 
-        ids, num_samples_map = cls._make_split_files(defaultdict(lambda: dataset_folder, {"train_noval": root}))
+        ids, num_samples_map = cls._make_split_files(
+            defaultdict(lambda: dataset_folder, {"train_noval": root}), split=split
+        )
         sizes = cls._make_anns_folder(dataset_folder, "cls", ids)
         create_image_folder(
             dataset_folder, "img", lambda idx: f"{ids[idx]}.jpg", num_examples=len(ids), size=lambda idx: sizes[idx]
         )
 
         make_tar(root, "benchmark.tgz", archive_folder, compression="gz")
 
-        return num_samples_map
+        return num_samples_map[split]
 
 
 @register_mock(configs=combinations_grid(split=("train", "val", "train_noval")))
 def sbd(root, config):
-    return SBDMockData.generate(root)[config["split"]]
+    return SBDMockData.generate(root, split=config["split"])
 
 
 @register_mock(configs=[dict()])

test/test_prototype_datasets_builtin.py

Lines changed: 57 additions & 13 deletions
@@ -1,6 +1,7 @@
 import functools
 import io
 import pickle
+from collections import deque
 from pathlib import Path
 
 import pytest
@@ -11,10 +12,11 @@
 from torch.utils.data.graph import traverse_dps
 from torch.utils.data.graph_settings import get_all_graph_pipes
 from torchdata.datapipes.iter import ShardingFilter, Shuffler
+from torchdata.datapipes.utils import StreamWrapper
 from torchvision._utils import sequence_to_str
-from torchvision.prototype import datasets, transforms
+from torchvision.prototype import datasets, features, transforms
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE
-from torchvision.prototype.features import Image, Label
+
 
 assert_samples_equal = functools.partial(
     assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
@@ -25,6 +27,17 @@ def extract_datapipes(dp):
     return get_all_graph_pipes(traverse_dps(dp))
 
 
+def consume(iterator):
+    # Copied from the official itertools recipes: https://docs.python.org/3/library/itertools.html#itertools-recipes
+    deque(iterator, maxlen=0)
+
+
+def next_consume(iterator):
+    item = next(iterator)
+    consume(iterator)
+    return item
+
+
 @pytest.fixture(autouse=True)
 def test_home(mocker, tmp_path):
     mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -66,7 +79,7 @@ def test_sample(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
         try:
-            sample = next(iter(dataset))
+            sample = next_consume(iter(dataset))
         except StopIteration:
             raise AssertionError("Unable to draw any sample.") from None
         except Exception as error:
@@ -84,22 +97,53 @@ def test_num_samples(self, dataset_mock, config):
 
         assert len(list(dataset)) == mock_info["num_samples"]
 
+    @pytest.fixture
+    def log_session_streams(self):
+        debug_unclosed_streams = StreamWrapper.debug_unclosed_streams
+        try:
+            StreamWrapper.debug_unclosed_streams = True
+            yield
+        finally:
+            StreamWrapper.debug_unclosed_streams = debug_unclosed_streams
+
     @parametrize_dataset_mocks(DATASET_MOCKS)
-    def test_no_vanilla_tensors(self, dataset_mock, config):
+    def test_stream_closing(self, log_session_streams, dataset_mock, config):
+        def make_msg_and_close(head):
+            unclosed_streams = []
+            for stream in StreamWrapper.session_streams.keys():
+                unclosed_streams.append(repr(stream.file_obj))
+                stream.close()
+            unclosed_streams = "\n".join(unclosed_streams)
+            return f"{head}\n\n{unclosed_streams}"
+
+        if StreamWrapper.session_streams:
+            raise pytest.UsageError(make_msg_and_close("A previous test did not close the following streams:"))
+
         dataset, _ = dataset_mock.load(config)
 
-        vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
-        if vanilla_tensors:
+        consume(iter(dataset))
+
+        if StreamWrapper.session_streams:
+            raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:"))
+
+    @parametrize_dataset_mocks(DATASET_MOCKS)
+    def test_no_simple_tensors(self, dataset_mock, config):
+        dataset, _ = dataset_mock.load(config)
+
+        simple_tensors = {key for key, value in next_consume(iter(dataset)).items() if features.is_simple_tensor(value)}
+        if simple_tensors:
             raise AssertionError(
                 f"The values of key(s) "
-                f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
+                f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors."
             )
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_transformable(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
-        next(iter(dataset.map(transforms.Identity())))
+        dataset = dataset.map(transforms.Identity())
+
+        consume(iter(dataset))
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_traversable(self, dataset_mock, config):
@@ -131,7 +175,7 @@ def test_data_loader(self, dataset_mock, config, num_workers):
             collate_fn=self._collate_fn,
         )
 
-        next(iter(dl))
+        consume(dl)
 
         # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
         # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
@@ -148,7 +192,7 @@ def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
     def test_save_load(self, dataset_mock, config):
         dataset, _ = dataset_mock.load(config)
 
-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))
 
         with io.BytesIO() as buffer:
             torch.save(sample, buffer)
@@ -177,7 +221,7 @@ class TestQMNIST:
     def test_extra_label(self, dataset_mock, config):
        dataset, _ = dataset_mock.load(config)
 
-        sample = next(iter(dataset))
+        sample = next_consume(iter(dataset))
         for key, type in (
             ("nist_hsf_series", int),
             ("nist_writer_id", int),
@@ -214,7 +258,7 @@ def test_sample_content(self, dataset_mock, config):
         assert "image" in sample
         assert "label" in sample
 
-        assert isinstance(sample["image"], Image)
-        assert isinstance(sample["label"], Label)
+        assert isinstance(sample["image"], features.Image)
+        assert isinstance(sample["label"], features.Label)
 
         assert sample["image"].shape == (1, 16, 16)
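
The switch from next(iter(dataset)) to next_consume(iter(dataset)) matters because cleanup code placed after the yield loop in a datapipe's __iter__ only runs once the iterator is exhausted; drawing a single sample and abandoning the iterator would leave the streams open. A hypothetical stand-alone illustration of the difference (not taken from the commit):

from collections import deque


def consume(iterator):
    # Same recipe as in the test file above: exhaust the iterator.
    deque(iterator, maxlen=0)


def next_consume(iterator):
    item = next(iterator)
    consume(iterator)
    return item


def samples():
    file = open(__file__, "rb")  # stand-in for a dataset stream
    for line in file:
        yield line
    # Mirrors the cleanup the datasets now perform after their yield loops.
    file.close()
    print("stream closed")


next(iter(samples()))          # stops after one item; the close() line is never reached
next_consume(iter(samples()))  # draws the first item, then drains the rest, so close() runs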

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 6 additions & 4 deletions
@@ -30,24 +30,26 @@ def __init__(
 
     def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
         for _, file in self.datapipe:
-            file = (line.decode() for line in file)
+            lines = (line.decode() for line in file)
 
             if self.fieldnames:
                 fieldnames = self.fieldnames
             else:
                 # The first row is skipped, because it only contains the number of samples
-                next(file)
+                next(lines)
 
                 # Empty field names are filtered out, because some files have an extra white space after the header
                 # line, which is recognized as extra column
-                fieldnames = [name for name in next(csv.reader([next(file)], dialect="celeba")) if name]
+                fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
                 # Some files do not include a label for the image ID column
                 if fieldnames[0] != "image_id":
                     fieldnames.insert(0, "image_id")
 
-            for line in csv.DictReader(file, fieldnames=fieldnames, dialect="celeba"):
+            for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
                 yield line.pop("image_id"), line
 
+            file.close()
+
 
 NAME = "celeba"
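
The rename from file to lines is what makes the new file.close() call possible: the decoded generator and the underlying handle are different objects, and rebinding the name file to the generator would leave nothing to close. A hypothetical, self-contained illustration of the same shape of code (the real handles come from the datapipe, and the real files use the "celeba" CSV dialect rather than a plain space delimiter):

import csv
import io

# Toy stand-in for one CelebA annotation file.
file = io.BytesIO(b"image_id attr\n000001.jpg 1\n")
lines = (line.decode() for line in file)  # decoded view used for parsing

reader = csv.DictReader(lines, fieldnames=["image_id", "attr"], delimiter=" ")
rows = list(reader)

file.close()  # cleanup still targets the original handle
assert file.closed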

torchvision/prototype/datasets/_builtin/cifar.py

Lines changed: 3 additions & 1 deletion
@@ -62,7 +62,9 @@ def _resources(self) -> List[OnlineResource]:
 
     def _unpickle(self, data: Tuple[str, io.BytesIO]) -> Dict[str, Any]:
         _, file = data
-        return cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
+        content = cast(Dict[str, Any], pickle.load(file, encoding="latin1"))
+        file.close()
+        return content
 
     def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]:
         image_array, category_idx = data

torchvision/prototype/datasets/_builtin/clevr.py

Lines changed: 2 additions & 0 deletions
@@ -97,6 +97,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
         else:
+            for _, file in scenes_dp:
+                file.close()
             dp = Mapper(images_dp, self._add_empty_anns)
 
         return Mapper(dp, self._prepare_sample)
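
For splits without scene annotations, the scenes branch of the Demultiplexer is never zipped into the sample pipeline, so its handles would otherwise stay open for the whole session; the added loop drains the branch and closes each stream eagerly. The same idea as a small stand-alone helper (a sketch only, not an API of the datasets):

def close_unused_branch(branch_dp):
    # Iterate the (path, stream) pairs of a datapipe branch that will not be
    # consumed downstream and close each stream right away.
    for _, file in branch_dp:
        file.close()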

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ def __iter__(self) -> Iterator[torch.Tensor]:
             for _ in range(stop - start):
                 yield read(dtype=dtype, count=count).reshape(shape)
 
+            file.close()
+
 
 class _MNISTBase(Dataset):
     _URL_BASE: Union[str, Sequence[str]]

torchvision/prototype/datasets/_builtin/pcam.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
                     data = data[self.key]
                 yield from data
 
+            handle.close()
+
 
 _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))

torchvision/prototype/datasets/_builtin/sbd.py

Lines changed: 39 additions & 29 deletions
@@ -49,31 +49,35 @@ def __init__(
         super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check)
 
     def _resources(self) -> List[OnlineResource]:
-        archive = HttpResource(
-            "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
-            sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
-        )
-        extra_split = HttpResource(
-            "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
-            sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
-        )
-        return [archive, extra_split]
+        resources = [
+            HttpResource(
+                "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz",
+                sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53",
+            )
+        ]
+        if self._split == "train_noval":
+            resources.append(
+                HttpResource(
+                    "http://home.bharathh.info/pubs/codes/SBD/train_noval.txt",
+                    sha256="0b2068f7a359d2907431803e1cd63bf6162da37d7d503b589d3b08c6fd0c2432",
+                )
+            )
+        return resources  # type: ignore[return-value]
 
     def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         path = pathlib.Path(data[0])
         parent, grandparent, *_ = path.parents
 
-        if parent.name == "dataset":
-            return 0
-        elif grandparent.name == "dataset":
+        if grandparent.name == "dataset":
             if parent.name == "img":
-                return 1
+                return 0
             elif parent.name == "cls":
-                return 2
-            else:
-                return None
-        else:
-            return None
+                return 1
+
+        if parent.name == "dataset" and self._split != "train_noval":
+            return 2
+
+        return None
 
     def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
         split_and_image_data, ann_data = data
@@ -93,18 +97,24 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st
         )
 
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
-        archive_dp, extra_split_dp = resource_dps
-
-        archive_dp = resource_dps[0]
-        split_dp, images_dp, anns_dp = Demultiplexer(
-            archive_dp,
-            3,
-            self._classify_archive,
-            buffer_size=INFINITE_BUFFER_SIZE,
-            drop_none=True,
-        )
         if self._split == "train_noval":
-            split_dp = extra_split_dp
+            archive_dp, split_dp = resource_dps
+            images_dp, anns_dp = Demultiplexer(
+                archive_dp,
+                2,
+                self._classify_archive,
+                buffer_size=INFINITE_BUFFER_SIZE,
+                drop_none=True,
+            )
+        else:
+            archive_dp = resource_dps[0]
+            images_dp, anns_dp, split_dp = Demultiplexer(
+                archive_dp,
+                3,
+                self._classify_archive,
+                buffer_size=INFINITE_BUFFER_SIZE,
+                drop_none=True,
+            )
 
         split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt"))
         split_dp = LineReader(split_dp, decode=True)
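
The core of the SBD refactor is that the number of Demultiplexer outputs now matches what the split actually needs: for "train_noval" the split list ships as a separate resource, so only an image branch and an annotation branch are created and no split branch is opened from the archive only to be discarded. A hedged, toy sketch of how classifier indices map onto Demultiplexer outputs (the file listing below is made up; the real datapipe yields (path, stream) pairs from the downloaded archive):

from torchdata.datapipes.iter import Demultiplexer, IterableWrapper

archive_dp = IterableWrapper(
    [
        ("benchmark_RELEASE/dataset/img/2008_000001.jpg", "image-stream"),
        ("benchmark_RELEASE/dataset/cls/2008_000001.mat", "ann-stream"),
        ("benchmark_RELEASE/dataset/train.txt", "split-stream"),
    ]
)


def classify(data):
    # Simplified stand-in for _classify_archive: 0 -> images, 1 -> annotations,
    # 2 -> split files; everything else returns None and is dropped below.
    path = data[0]
    if "/img/" in path:
        return 0
    if "/cls/" in path:
        return 1
    if path.endswith(".txt"):
        return 2
    return None


images_dp, anns_dp, split_dp = Demultiplexer(archive_dp, 3, classify, drop_none=True)
print(list(split_dp))  # [('benchmark_RELEASE/dataset/train.txt', 'split-stream')]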

torchvision/prototype/datasets/_builtin/voc.py

Lines changed: 3 additions & 1 deletion
@@ -94,7 +94,9 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
         return None
 
     def _parse_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
-        return cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+        ann = cast(Dict[str, Any], VOCDetection.parse_voc_xml(ElementTree.parse(buffer).getroot())["annotation"])
+        buffer.close()
+        return ann
 
     def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]:
         anns = self._parse_detection_ann(buffer)

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 3 additions & 5 deletions
@@ -8,7 +8,6 @@
 import torch.distributed as dist
 import torch.utils.data
 from torchdata.datapipes.iter import IoPathFileLister, IoPathFileOpener, IterDataPipe, ShardingFilter, Shuffler
-from torchdata.datapipes.utils import StreamWrapper
 from torchvision.prototype.utils._internal import fromfile
 
 
@@ -40,10 +39,9 @@ def read_mat(buffer: BinaryIO, **kwargs: Any) -> Any:
     except ImportError as error:
         raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error
 
-    if isinstance(buffer, StreamWrapper):
-        buffer = buffer.file_obj
-
-    return sio.loadmat(buffer, **kwargs)
+    data = sio.loadmat(buffer, **kwargs)
+    buffer.close()
+    return data
 
 
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
