Skip to content

Commit b63e249

Browse files
Yosua Michael Maranatha authored and facebook-github-bot committed
[fbsync] [DataPipe] Properly cleanup unclosed files within generator function (#6997)
Summary:
Co-authored-by: Philip Meier <[email protected]>
Reviewed By: datumbox
Differential Revision: D41836894
fbshipit-source-id: b3a8c76245298f85b6ff19107bf2ef872bfb033a
1 parent f38a417 commit b63e249

File tree

3 files changed: 44 additions and 41 deletions

torchvision/prototype/datasets/_builtin/celeba.py

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -30,25 +30,26 @@ def __init__(
3030

3131
def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]:
3232
for _, file in self.datapipe:
33-
lines = (line.decode() for line in file)
34-
35-
if self.fieldnames:
36-
fieldnames = self.fieldnames
37-
else:
38-
# The first row is skipped, because it only contains the number of samples
39-
next(lines)
40-
41-
# Empty field names are filtered out, because some files have an extra white space after the header
42-
# line, which is recognized as extra column
43-
fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
44-
# Some files do not include a label for the image ID column
45-
if fieldnames[0] != "image_id":
46-
fieldnames.insert(0, "image_id")
47-
48-
for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
49-
yield line.pop("image_id"), line
50-
51-
file.close()
33+
try:
34+
lines = (line.decode() for line in file)
35+
36+
if self.fieldnames:
37+
fieldnames = self.fieldnames
38+
else:
39+
# The first row is skipped, because it only contains the number of samples
40+
next(lines)
41+
42+
# Empty field names are filtered out, because some files have an extra white space after the header
43+
# line, which is recognized as extra column
44+
fieldnames = [name for name in next(csv.reader([next(lines)], dialect="celeba")) if name]
45+
# Some files do not include a label for the image ID column
46+
if fieldnames[0] != "image_id":
47+
fieldnames.insert(0, "image_id")
48+
49+
for line in csv.DictReader(lines, fieldnames=fieldnames, dialect="celeba"):
50+
yield line.pop("image_id"), line
51+
finally:
52+
file.close()
5253

5354

5455
NAME = "celeba"

torchvision/prototype/datasets/_builtin/mnist.py

Lines changed: 17 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -37,27 +37,28 @@ def __init__(
3737

3838
def __iter__(self) -> Iterator[torch.Tensor]:
3939
for _, file in self.datapipe:
40-
read = functools.partial(fromfile, file, byte_order="big")
40+
try:
41+
read = functools.partial(fromfile, file, byte_order="big")
4142

42-
magic = int(read(dtype=torch.int32, count=1))
43-
dtype = self._DTYPE_MAP[magic // 256]
44-
ndim = magic % 256 - 1
43+
magic = int(read(dtype=torch.int32, count=1))
44+
dtype = self._DTYPE_MAP[magic // 256]
45+
ndim = magic % 256 - 1
4546

46-
num_samples = int(read(dtype=torch.int32, count=1))
47-
shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
48-
count = prod(shape) if shape else 1
47+
num_samples = int(read(dtype=torch.int32, count=1))
48+
shape = cast(List[int], read(dtype=torch.int32, count=ndim).tolist()) if ndim else []
49+
count = prod(shape) if shape else 1
4950

50-
start = self.start or 0
51-
stop = min(self.stop, num_samples) if self.stop else num_samples
51+
start = self.start or 0
52+
stop = min(self.stop, num_samples) if self.stop else num_samples
5253

53-
if start:
54-
num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
55-
file.seek(num_bytes_per_value * count * start, 1)
54+
if start:
55+
num_bytes_per_value = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
56+
file.seek(num_bytes_per_value * count * start, 1)
5657

57-
for _ in range(stop - start):
58-
yield read(dtype=dtype, count=count).reshape(shape)
59-
60-
file.close()
58+
for _ in range(stop - start):
59+
yield read(dtype=dtype, count=count).reshape(shape)
60+
finally:
61+
file.close()
6162

6263

6364
class _MNISTBase(Dataset):

torchvision/prototype/datasets/_builtin/pcam.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -28,12 +28,13 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]:
2828
import h5py
2929

3030
for _, handle in self.datapipe:
31-
with h5py.File(handle) as data:
32-
if self.key is not None:
33-
data = data[self.key]
34-
yield from data
35-
36-
handle.close()
31+
try:
32+
with h5py.File(handle) as data:
33+
if self.key is not None:
34+
data = data[self.key]
35+
yield from data
36+
finally:
37+
handle.close()
3738

3839

3940
_Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256"))

0 commit comments

Comments
 (0)