Skip to content

Commit 0789218

Browse files
committed
Fixing HeaderIterDP's __len__ function
ghstack-source-id: c8feaf7 Pull Request resolved: #166
1 parent 24c25c0 commit 0789218

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

test/test_datapipe.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,24 @@ def test_header_iterdatapipe(self) -> None:
307307
# __len__ Test: returns the limit when it is less than the length of source
308308
self.assertEqual(5, len(header_dp))
309309

310-
# TODO(123): __len__ Test: returns the length of source when it is less than the limit
311-
# header_dp = source_dp.header(30)
312-
# self.assertEqual(20, len(header_dp))
310+
# __len__ Test: returns the length of source when it is less than the limit
311+
header_dp = source_dp.header(30)
312+
self.assertEqual(20, len(header_dp))
313+
314+
# __len__ Test: returns limit if source doesn't have length
315+
source_dp_NoLen = IDP_NoLen(list(range(20)))
316+
header_dp = source_dp_NoLen.header(30)
317+
with warnings.catch_warnings(record=True) as wa:
318+
self.assertEqual(30, len(header_dp))
319+
self.assertEqual(len(wa), 1)
320+
self.assertRegex(
321+
str(wa[0].message), r"length of this HeaderIterDataPipe is inferred to be equal to its limit"
322+
)
323+
324+
# __len__ Test: returns limit if source doesn't have length, but it has been iterated through once
325+
for _ in header_dp:
326+
pass
327+
self.assertEqual(20, len(header_dp))
313328

314329
def test_enumerator_iterdatapipe(self) -> None:
315330
letters = "abcde"

torchdata/datapipes/iter/util/header.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates.
22
from typing import Iterator, TypeVar
3+
from warnings import warn
34

45
from torchdata.datapipes import functional_datapipe
56
from torchdata.datapipes.iter import IterDataPipe
@@ -20,14 +21,27 @@ class HeaderIterDataPipe(IterDataPipe[T_co]):
2021
def __init__(self, source_datapipe: IterDataPipe[T_co], limit: int = 10) -> None:
2122
self.source_datapipe: IterDataPipe[T_co] = source_datapipe
2223
self.limit: int = limit
24+
self.length = -1
2325

2426
def __iter__(self) -> Iterator[T_co]:
27+
i = -1
2528
for i, value in enumerate(self.source_datapipe):
2629
if i < self.limit:
2730
yield value
2831
else:
2932
break
33+
self.length = min(i + 1, self.limit) # We know length with certainty when we reach here
3034

31-
# TODO(134): Fix the case that the length of source_datapipe is shorter than limit
3235
def __len__(self) -> int:
33-
return self.limit
36+
if self.length != -1:
37+
return self.length
38+
try:
39+
source_len = len(self.source_datapipe)
40+
self.length = min(source_len, self.limit)
41+
return self.length
42+
except TypeError:
43+
warn(
44+
"The length of this HeaderIterDataPipe is inferred to be equal to its limit."
45+
"The actual value may be smaller if the actual length of source_datapipe is smaller than the limit."
46+
)
47+
return self.limit

0 commit comments

Comments
 (0)