Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
7699381
inital commit
divyanshk Nov 8, 2024
0886fcc
remove load_state_dict
divyanshk Nov 8, 2024
d8c9bab
barebones test for testing
divyanshk Nov 8, 2024
e238b32
testing
divyanshk Nov 11, 2024
334a6a5
loop over a single dataset
divyanshk Nov 11, 2024
2afec89
updates behaviour
andrewkho Nov 11, 2024
7a1a3e3
update docstring
andrewkho Nov 11, 2024
86efb63
add global wrapper node, need to think of better name
andrewkho Nov 11, 2024
d2f0ea0
add to init
andrewkho Nov 11, 2024
02b4591
Rename to Loader
andrewkho Nov 11, 2024
f3d4449
Rename to Loader
andrewkho Nov 11, 2024
bdf0244
working example with PR#1358
divyanshk Nov 11, 2024
07becf6
working example with PR#1358
divyanshk Nov 11, 2024
8c51d8f
Merge branch 'main' into divyanshk/multi-dataset-mixer
divyanshk Nov 12, 2024
9652f4a
Increase state_dict test coverage, fix snapshotting bug in mappers
andrewkho Nov 12, 2024
09d0695
Merge branch 'andrewkh/add-global-wrapper-node' into divyanshk/multi-…
andrewkho Nov 12, 2024
55bd285
partial
andrewkho Nov 12, 2024
52ccde2
convert generators to explicit iterators
andrewkho Nov 12, 2024
25db901
Merge branch 'andrewkh/retool-for-iterator-only' into divyanshk/multi…
andrewkho Nov 12, 2024
6cb6d34
convert multi dataset generator to iterator
andrewkho Nov 12, 2024
6948a22
remove print
andrewkho Nov 12, 2024
6bfba97
temp
andrewkho Nov 12, 2024
416ffb9
update pin_memory
andrewkho Nov 12, 2024
72f5718
update map
andrewkho Nov 12, 2024
7b716b2
fix base_node tests
andrewkho Nov 12, 2024
e10a00a
update docstring
andrewkho Nov 12, 2024
7b8abad
split up diff into multiple
andrewkho Nov 12, 2024
a6cad52
add mypy ignores
andrewkho Nov 13, 2024
af610c5
merge main
andrewkho Nov 13, 2024
1fa022f
change test back to forkserver
andrewkho Nov 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions test/nodes/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_iterable(self):
n = 20
node = IterableWrapper(range(n))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, j in enumerate(result):
Expand All @@ -61,8 +62,9 @@ def test_generator(self):

def test_iterable_dataset(self):
n = 20
node = IterableWrapper(DummyIterableDataset(n))
node = IterableWrapper(DummyIterableDataset(n, name="test"))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
Expand All @@ -84,6 +86,7 @@ def test_default_sampler(self):
n = 20
node = MapStyleWrapper(DummyMapDataset(n), sampler=range(n))
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
Expand All @@ -97,6 +100,7 @@ def test_random_sampler(self):
node = MapStyleWrapper(ds, sampler=RandomSampler(ds))
results = []
for epoch in range(2):
node.reset()
result = list(node)
results.append(result)
self.assertEqual(len(result), n)
Expand All @@ -116,6 +120,7 @@ def test_dict(self):
sampler = list(d.keys())
node = MapStyleWrapper(d, sampler=sampler)
for epoch in range(2):
node.reset()
result = list(node)
self.assertEqual(len(result), n)
for i, row in enumerate(result):
Expand Down Expand Up @@ -145,9 +150,10 @@ def test_sampler_wrapper(self):

results = []
for epoch in range(2):
node.reset()
self.assertEqual(node.epoch, epoch)
result = list(node)
results.append(result)
self.assertEqual(node._epoch, epoch)
self.assertEqual(len(result), n)
self.assertEqual(set(result), set(range(n)))

Expand All @@ -167,6 +173,7 @@ def test_distributed_sampler(self):
node = SamplerWrapper(sampler=sampler)

for epoch in range(4):
node.reset()
result = list(node)
self.assertEqual(result, exp[epoch])

Expand Down
16 changes: 0 additions & 16 deletions test/nodes/test_base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,10 @@

from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.base_node import BaseNodeIterator

from .utils import run_test_save_load_state


class TestBaseNode(TestCase):
def test_started_finished(self) -> None:
x = IterableWrapper(range(10))
for _ in range(3): # test multi-epoch
it = iter(x)
self.assertIsInstance(it, BaseNodeIterator)
self.assertFalse(it.started())
self.assertFalse(it.finished())

for _ in it:
self.assertTrue(it.started())
self.assertFalse(it.finished())

self.assertTrue(it.started())
self.assertTrue(it.finished())

def test_save_load_state(self):
run_test_save_load_state(self, IterableWrapper(range(10)), 5)
9 changes: 4 additions & 5 deletions test/nodes/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class TestMap(TestCase):
def _test_exception_handling_mapper(self, pin_memory, method):
batch_size = 6
multiprocessing_context = None if IS_WINDOWS else "forkserver"
multiprocessing_context = None if IS_WINDOWS else "fork"
src = MockSource(num_samples=20)
node = Batcher(src, batch_size=batch_size)
node = ParallelMapper(
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_exception_handling_mapper_multiprocess_cuda(self):
def _test_map(self, in_order, method) -> None:
batch_size = 6
n = 80
multiprocessing_context = None if IS_WINDOWS else "forkserver"
multiprocessing_context = None if IS_WINDOWS else "fork"
src = MockSource(num_samples=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = ParallelMapper(
Expand All @@ -73,6 +73,7 @@ def _test_map(self, in_order, method) -> None:

results: List[List[dict]] = [[], []]
for epoch in range(2):
node.reset()
for batch in node:
results[epoch].extend(batch)

Expand Down Expand Up @@ -119,7 +120,6 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
method = "thread"
batch_size = 6
n = 80
multiprocessing_context = None if IS_WINDOWS else "forkserver"
src = MockSource(num_samples=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = ParallelMapper(
Expand All @@ -128,7 +128,6 @@ def test_save_load_state_thread(self, midpoint: int, in_order: bool, snapshot_fr
num_workers=4,
in_order=in_order,
method=method,
multiprocessing_context=multiprocessing_context,
snapshot_frequency=snapshot_frequency,
)
node = Prefetcher(node, prefetch_factor=2)
Expand All @@ -145,7 +144,7 @@ def test_save_load_state_process(self, midpoint: int, in_order: bool, snapshot_f
method = "process"
batch_size = 6
n = 80
multiprocessing_context = None if IS_WINDOWS else "forkserver"
multiprocessing_context = None if IS_WINDOWS else "fork"
src = MockSource(num_samples=n)
node = Batcher(src, batch_size=batch_size, drop_last=False)
node = ParallelMapper(
Expand Down
1 change: 1 addition & 0 deletions test/nodes/test_pin_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_pin_memory(self) -> None:

# 2 epochs
for epoch in range(2):
root.reset()
results = list(root)
self.assertEqual(len(results), 3, epoch)
for i in range(3):
Expand Down
1 change: 1 addition & 0 deletions test/nodes/test_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def test_prefetcher(self) -> None:

# Test multi epoch shutdown and restart
for _ in range(2):
root.reset()
results = list(root)
self.assertEqual(len(results), 3)
for i in range(3):
Expand Down
4 changes: 0 additions & 4 deletions test/nodes/test_snapshot_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@
# LICENSE file in the root directory of this source tree.

from torch.testing._internal.common_utils import TestCase
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.base_node import BaseNodeIterator
from torchdata.nodes.snapshot_store import DequeSnapshotStore

from .utils import run_test_save_load_state


class TestDequeSnapshotStore(TestCase):
def test_snapshot_store(self) -> None:
Expand Down
46 changes: 38 additions & 8 deletions test/nodes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import random
import time
from typing import Any, Dict, Iterator, Optional

import torch
from torchdata.nodes.adapters import IterableWrapper
from torchdata.nodes.base_node import BaseNode
from torchdata.nodes.loader import Loader


class MockGenerator:
Expand Down Expand Up @@ -50,22 +50,33 @@ def __call__(self, x):

class IterInitError(BaseNode[int]):
def __init__(self, msg: str = "Iter Init Error") -> None:
super().__init__()
self.msg = msg

def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[int]:
def reset(self, initial_state: Optional[Dict[str, Any]] = None):
super().reset(initial_state)
raise ValueError(self.msg)

def next(self):
raise ValueError("next() should not be called")

def get_state(self) -> Dict[str, Any]:
return {}


class DummyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, num_samples: int) -> None:
def __init__(self, num_samples: int, name: str) -> None:
self.num_samples = num_samples
self.name = name

def __iter__(self) -> Iterator[dict]:
for i in range(self.num_samples):
yield {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}
yield {
"name": self.name,
"step": i,
"test_tensor": torch.tensor([i]),
"test_str": f"str_{i}",
}


class DummyMapDataset(torch.utils.data.Dataset):
Expand All @@ -79,9 +90,11 @@ def __getitem__(self, i: int) -> dict:
return {"step": i, "test_tensor": torch.tensor([i]), "test_str": f"str_{i}"}


def run_test_save_load_state(test, x: BaseNode, midpoint: int):
def run_test_save_load_state(test, node: BaseNode, midpoint: int):
##############################
# Generate initial, midpoint, and end state_dict's
x = Loader(node)

initial_state_dict = x.state_dict()
it = iter(x)
results = []
Expand All @@ -94,7 +107,16 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int):
state_dict_0_end = x.state_dict()

# store epoch 1's results
results_1 = list(x)
it = iter(x)
results_1 = []
for i in range(midpoint):
results_1.append(next(it))
state_dict_1 = x.state_dict()
for val in it:
results_1.append(val)

# for random sequences, there are no guarantees that the results will be the same length
# test.assertEqual(len(results_1), len(results))

##############################
# Test restoring from midpoint
Expand All @@ -106,6 +128,12 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int):
results_after_1 = list(x)
test.assertEqual(results_after_1, results_1)

##############################
# Test restoring from midpoint of epoch 1
x.load_state_dict(state_dict_1)
results_after_2 = list(x)
test.assertEqual(results_after_2, results_1[midpoint:])

##############################
# Test initialize from beginning after resume
x.load_state_dict(initial_state_dict)
Expand All @@ -116,12 +144,14 @@ def run_test_save_load_state(test, x: BaseNode, midpoint: int):

##############################
# Test restoring from end-of-epoch 0
x.load_state_dict(state_dict_0_end, restart_on_stop_iteration=False)
x = Loader(node, restart_on_stop_iteration=False)
x.load_state_dict(state_dict_0_end)
results_after_dict_0_with_restart_false = list(x)
test.assertEqual(results_after_dict_0_with_restart_false, [])

##############################
# Test restoring from end of epoch 0 with restart_on_stop_iteration=True
x.load_state_dict(copy.deepcopy(state_dict_0_end), restart_on_stop_iteration=True)
x = Loader(node)
x.load_state_dict(state_dict_0_end)
results_after_dict_0 = list(x)
test.assertEqual(results_after_dict_0, results_1)
2 changes: 2 additions & 0 deletions torchdata/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .adapters import IterableWrapper, MapStyleWrapper
from .base_node import BaseNode, T
from .batch import Batcher
from .loader import Loader
from .map import Mapper, ParallelMapper
from .pin_memory import PinMemory
from .prefetch import Prefetcher
Expand All @@ -17,6 +18,7 @@
"BaseNode",
"Batcher",
"IterableWrapper",
"Loader",
"MapStyleWrapper",
"Mapper",
"ParallelMapper",
Expand Down
Loading