Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7699381
initial commit
divyanshk Nov 8, 2024
0886fcc
remove load_state_dict
divyanshk Nov 8, 2024
d8c9bab
barebones test for testing
divyanshk Nov 8, 2024
e238b32
testing
divyanshk Nov 11, 2024
334a6a5
loop over a single dataset
divyanshk Nov 11, 2024
2afec89
updates behaviour
andrewkho Nov 11, 2024
7a1a3e3
update docstring
andrewkho Nov 11, 2024
86efb63
add global wrapper node, need to think of better name
andrewkho Nov 11, 2024
d2f0ea0
add to init
andrewkho Nov 11, 2024
02b4591
Rename to Loader
andrewkho Nov 11, 2024
f3d4449
Rename to Loader
andrewkho Nov 11, 2024
bdf0244
working example with PR#1358
divyanshk Nov 11, 2024
07becf6
working example with PR#1358
divyanshk Nov 11, 2024
8c51d8f
Merge branch 'main' into divyanshk/multi-dataset-mixer
divyanshk Nov 12, 2024
9652f4a
Increase state_dict test coverage, fix snapshotting bug in mappers
andrewkho Nov 12, 2024
09d0695
Merge branch 'andrewkh/add-global-wrapper-node' into divyanshk/multi-…
andrewkho Nov 12, 2024
55bd285
partial
andrewkho Nov 12, 2024
52ccde2
convert generators to explicit iterators
andrewkho Nov 12, 2024
25db901
Merge branch 'andrewkh/retool-for-iterator-only' into divyanshk/multi…
andrewkho Nov 12, 2024
6cb6d34
convert multi dataset generator to iterator
andrewkho Nov 12, 2024
6948a22
remove print
andrewkho Nov 12, 2024
6bfba97
temp
andrewkho Nov 12, 2024
416ffb9
update pin_memory
andrewkho Nov 12, 2024
72f5718
update map
andrewkho Nov 12, 2024
7b716b2
fix base_node tests
andrewkho Nov 12, 2024
e10a00a
update docstring
andrewkho Nov 12, 2024
7b8abad
split up diff into multiple
andrewkho Nov 12, 2024
a6cad52
add mypy ignores
andrewkho Nov 13, 2024
af610c5
merge main
andrewkho Nov 13, 2024
1fa022f
change test back to forkserver
andrewkho Nov 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions test/nodes/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_iterable(self):
n = 20
node = IterableWrapper(range(n))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, j in enumerate(result):
Expand All @@ -61,8 +62,9 @@ def test_generator(self):

def test_iterable_dataset(self):
n = 20
node = IterableWrapper(DummyIterableDataset(n))
node = IterableWrapper(DummyIterableDataset(n, name="test"))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
Expand All @@ -84,6 +86,7 @@ def test_default_sampler(self):
n = 20
node = MapStyleWrapper(DummyMapDataset(n), sampler=range(n))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
Expand All @@ -97,6 +100,7 @@ def test_random_sampler(self):
node = MapStyleWrapper(ds, sampler=RandomSampler(ds))
results = []
for epoch in range(2):
node.reset()
result = list(node)
results.append(result)
self.assertEqual(len(result), n)
Expand All @@ -116,6 +120,7 @@ def test_dict(self):
sampler = list(d.keys())
node = MapStyleWrapper(d, sampler=sampler)
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
Expand Down Expand Up @@ -145,9 +150,10 @@ def test_sampler_wrapper(self):

results = []
for epoch in range(2):
node.reset()
self.assertEqual(node.epoch, epoch)
result = list(node)
results.append(result)
self.assertEqual(node._epoch, epoch)
self.assertEqual(len(result), n)
self.assertEqual(set(result), set(range(n)))

Expand All @@ -167,6 +173,7 @@ def test_distributed_sampler(self):
node = SamplerWrapper(sampler=sampler)

for epoch in range(4):
node.reset()
result = list(node)
self.assertEqual(result, exp[epoch])

Expand Down
16 changes: 0 additions & 16 deletions test/nodes/test_base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,10 @@

from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.base_node import BaseNodeIterator

from .utils import run_test_save_load_state


class TestBaseNode(TestCase):
def test_started_finished(self) -> None:
x = IterableWrapper(range(10))
for _ in range(3): # test multi-epoch
it = iter(x)
self.assertIsInstance(it, BaseNodeIterator)
self.assertFalse(it.started())
self.assertFalse(it.finished())

for _ in it:
self.assertTrue(it.started())
self.assertFalse(it.finished())

self.assertTrue(it.started())
self.assertTrue(it.finished())

def test_save_load_state(self):
run_test_save_load_state(self, IterableWrapper(range(10)), 5)
3 changes: 1 addition & 2 deletions test/nodes/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def _test_map(self, in_order, method) -> None:

results: List[List[dict]] = [[], []]
for epoch in range(2):
node.reset()
for batch in node:
results[epoch].extend(batch)

Expand Down Expand Up @@ -119,7 +120,6 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
method = "thread"
batch_size = 6
n = 80
multiprocessing_context = None if IS_WINDOWS else "forkserver"
src = MockSource(num_samples=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = ParallelMapper(
Expand All @@ -128,7 +128,6 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
num_workers=4,
in_order=in_order,
method=method,
multiprocessing_context=multiprocessing_context,
snapshot_frequency=snapshot_frequency,
)
node = Prefetcher(node, prefetch_factor=2)
Expand Down
1 change: 1 addition & 0 deletions test/nodes/test_pin_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_pin_memory(self) -> None:

# 2 epochs
for epoch in range(2):
root.reset()
results = list(root)
self.assertEqual(len(results), 3, epoch)
for i in range(3):
Expand Down
1 change: 1 addition & 0 deletions test/nodes/test_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_prefetcher(self) -> None:

# Test multi epoch shutdown and restart
for _ in range(2):
root.reset()
results = list(root)
self.assertEqual(len(results), 3)
for i in range(3):
Expand Down
4 changes: 0 additions & 4 deletions test/nodes/test_snapshot_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@
# LICENSE file in the root directory of this source tree.

from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.base_node import BaseNodeIterator
from torchdata.nodes.snapshot_store import DequeSnapshotStore

from .utils import run_test_save_load_state


class TestDequeSnapshotStore(TestCase):
def test_snapshot_store(self) -> None:
Expand Down
21 changes: 15 additions & 6 deletions test/nodes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,33 @@ def __call__(self, x):

class IterInitError(BaseNode[int]):
def __init__(self, msg: str = "Iter Init Error") -> None:
super().__init__()
self.msg = msg

def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[int]:
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
super().reset(initial_state)
raise ValueError(self.msg)

def next(self):
raise ValueError("next() should not be called")

def get_state(self) -> Dict[str, Any]:
return {}


class DummyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, num_samples: int) -> None:
def __init__(self, num_samples: int, name: str) -> None:
self.num_samples = num_samples
self.name = name

def __iter__(self) -> Iterator[dict]:
for i in range(self.num_samples):
yield {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}
yield {
"name": self.name,
"step": i,
"test_tensor": torch.tensor([i]),
"test_str": f"str_{i}",
}


class DummyMapDataset(torch.utils.data.Dataset):
Expand Down Expand Up @@ -104,8 +115,6 @@ def run_test_save_load_state(test, node: BaseNode, midpoint: int):
for val in it:
results_1.append(val)

assert len(results_1) == len(results)

##############################
# Test restoring from midpoint
x.load_state_dict(state_dict)
Expand All @@ -118,7 +127,7 @@ def run_test_save_load_state(test, node: BaseNode, midpoint: int):

##############################
# Test restoring from midpoint of epoch 1
x.load_state_dict(state_dict_1, restart_on_stop_iteration=True)
x.load_state_dict(state_dict_1)
results_after_2 = list(x)
test.assertEqual(results_after_2, results_1[midpoint:])

Expand Down
1 change: 1 addition & 0 deletions torchdata/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"Batcher",
"DataLoader",
"IterableWrapper",
"Loader",
"MapStyleWrapper",
"Mapper",
"ParallelMapper",
Expand Down
79 changes: 43 additions & 36 deletions torchdata/nodes/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,37 @@ class IterableWrapper(BaseNode[T]):
ITERABLE_KEY = "iterable"

def __init__(self, iterable: Iterable[T]):
super().__init__()
self.iterable = iterable
self._num_yielded = 0
self._it: Optional[Iterator[T]] = None

def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]:
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
self._num_yielded = 0
self._it = None
super().reset(initial_state)
if initial_state is not None:
self._num_yielded = initial_state[self.NUM_YIELDED_KEY]
if isinstance(self.iterable, Stateful):
self.iterable.load_state_dict(initial_state[self.ITERABLE_KEY])
it = iter(self.iterable)
self._it = iter(self.iterable)
else:
it = iter(self.iterable)
self._it = iter(self.iterable)
# Naively fast-forwarding
for i in range(self._num_yielded):
try:
next(it)
next(self._it)
except StopIteration:
raise ValueError(
f"Tried to fast-forward {self._num_yielded} items during init but "
f"hit StopIteration after {i} items, this is likely a bug or malformed state_dict"
)
else:
it = iter(self.iterable)
self._it = iter(self.iterable)

for item in it:
self._num_yielded += 1
yield item
def next(self) -> T:
item = next(self._it) # type: ignore [arg-type, union-attr]
Copy link
Contributor Author

@andrewkho andrewkho Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be a cleverer way to handle this, e.g. through a property, that won't make the type checkers complain. I would prefer to do that in a separate diff, as it's currently mostly sugar. Alternatively, we could assert that self._it is not None to keep linters happy, but I don't like the idea of adding code just to make linters happy.

self._num_yielded += 1
return item

def get_state(self) -> Dict[str, Any]:
state_dict: Dict[str, Any] = {self.NUM_YIELDED_KEY: self._num_yielded}
Expand Down Expand Up @@ -92,68 +96,71 @@ class SamplerWrapper(BaseNode[T]):
:param epoch_updater: Optional[Callable[[int], int]] = None - callback to update epoch at start of new iteration. It's called at the beginning of each iterator request, except the first one.
"""

NEXT_EPOCH_KEY = "_next_epoch"
NUM_YIELDED_KEY = "_num_yielded"
SAMPLER_KEY = "sampler"
EPOCH_KEY = "_epoch"
STARTED_KEY = "_started"

@classmethod
def _default_epoch_updater(cls, epoch: int) -> int:
return epoch + 1
SAMPLER_KEY = "_sampler"

def __init__(
self,
sampler: Sampler[T],
initial_epoch: int = 0,
epoch_updater: Optional[Callable[[int], int]] = None,
):
super().__init__()
self.sampler = sampler
self.epoch_updater = epoch_updater or self._default_epoch_updater
self.epoch = initial_epoch
self._num_yielded = 0
self._epoch = initial_epoch
self._started = False
self.epoch_updater = epoch_updater or self._default_epoch_updater
self._it: Optional[Iterator[T]] = None

def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]:
it: Iterator[T]
self._num_yielded = 0
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
super().reset(initial_state)
if initial_state is not None:
self._num_yielded = initial_state[self.NUM_YIELDED_KEY]
self._epoch = initial_state[self.EPOCH_KEY]
self._started = initial_state[self.STARTED_KEY]

self.epoch = initial_state[self.EPOCH_KEY]
if isinstance(self.sampler, Stateful):
self.sampler.load_state_dict(initial_state[self.SAMPLER_KEY])
it = iter(self.sampler)
self._it = iter(self.sampler) # type: ignore [assignment]
else:
if hasattr(self.sampler, "set_epoch"):
self.sampler.set_epoch(self._epoch)
it = iter(self.sampler)
print("Setting epoch", self.epoch)
self.sampler.set_epoch(self.epoch)
self._it = iter(self.sampler)
for i in range(self._num_yielded):
try:
next(it)
next(self._it) # type: ignore [arg-type]
except StopIteration:
raise ValueError(
f"Tried to fast-forward {self._num_yielded} items during init but "
f"hit StopIteration after {i} items, this is likely a bug or malformed state_dict"
)
else:
if self._started: # don't call first time
self._epoch = self.epoch_updater(self._epoch)
self._num_yielded = 0
if self._started:
# Don't update epoch unless iterator has started
self.epoch = self.epoch_updater(self.epoch)
if hasattr(self.sampler, "set_epoch"):
self.sampler.set_epoch(self._epoch)
it = iter(self.sampler)
self.sampler.set_epoch(self.epoch)
self._it = iter(self.sampler)
self._started = False

def next(self) -> T:
self._started = True
for item in it:
self._num_yielded += 1
yield item
item = next(self._it) # type: ignore [arg-type, union-attr]
self._num_yielded += 1
return item

def get_state(self) -> Dict[str, Any]:
state_dict: Dict[str, Any] = {
self.NUM_YIELDED_KEY: self._num_yielded,
self.EPOCH_KEY: self._epoch,
self.STARTED_KEY: self._started,
self.EPOCH_KEY: self.epoch,
}
if isinstance(self.sampler, Stateful):
state_dict[self.SAMPLER_KEY] = self.sampler.state_dict()
return state_dict

@classmethod
def _default_epoch_updater(cls, epoch: int) -> int:
return epoch + 1
Loading