[RFC] Retool for iterator only #1362
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1362

Note: links to docs will display an error until the docs builds have completed. ✅ You can merge normally! (1 unrelated failure.) As of commit 1fa022f with merge base 66a17ae: BROKEN TRUNK, the following job failed but was also present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Updating based on feedback / ideation in #1363
@andrewkho is the problem you highlighted with generators also in the other samplers we already have, e.g. in https://github.com/pytorch/data/blob/main/torchdata/stateful_dataloader/sampler.py#L63 and https://github.com/pytorch/pytorch/blob/main/torch/utils/data/sampler.py#L288?
```python
def reset(self, initial_state: Optional[dict] = None):
    self.__initialized = True
```
We don't force users to define reset?
We do, in a bit of a roundabout way: if users don't implement .reset() (and call super().reset()), we throw an error during next() and state_dict() calls.
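A sketch of that pattern (class name, message, and structure here are illustrative, not the exact implementation):

```python
class _GuardedNode:
    def __init__(self):
        self.__initialized = False

    def reset(self, initial_state=None):
        # Subclasses override this and must call super().reset(),
        # which is what flips the flag.
        self.__initialized = True

    def next(self):
        if not self.__initialized:
            raise NotImplementedError(
                "reset() was never called; subclasses must implement "
                "reset() and call super().reset() inside it"
            )
        ...  # real node logic goes here
```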
```python
def __init__(self, *args, **kwargs):
    self.__initialized = False
```
Trying to understand the role of self.__initialized: when we initialize a node (and call super().__init__() in it), we set __initialized to False, but we only set it to True when we reset.
This is here to enable .reset() to contain initialization logic which users might want lazily initialized, because it's expensive and they may e.g. want to load a state_dict before beginning iteration.
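For example (a hypothetical node, not code from this PR), an expensive resource can be opened in .reset() instead of __init__, so a caller can load a state_dict before the cost is paid:

```python
class ShardReader:
    def __init__(self, path: str):
        self.path = path
        self._file = None  # expensive handle is NOT opened here
        self._offset = 0

    def reset(self, initial_state=None):
        # Expensive setup is deferred to reset(), so the framework can
        # hand us a restored state_dict before iteration begins.
        if initial_state is not None:
            self._offset = initial_state["offset"]
        self._file = open(self.path, "rb")
        self._file.seek(self._offset)
```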
```python
if self._it is not None:
    self._it._shutdown()
    del self._it
self._it = _SingleThreadedMapper(
```
TODO: we should try to make _SingleThreadedMapper and ParallelMapper handle reset, similar to how persistent_workers currently works. However, that's going to take a more significant refactor, so I'm leaving it like this for now to minimize the changes.
```python
self._num_yielded += 1
yield item

def next(self) -> T:
    item = next(self._it)  # type: ignore [arg-type, union-attr]
```
There might be a more-clever way to handle this, e.g. through a property or something, that won't make the type-hinters complain. I'd prefer to do it in a separate diff as it's currently mostly sugar. Alternatively, we can also assert self._it is not None to keep linters happy, but I don't like the idea of adding code just to make linters happy.
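One possible shape of the property idea (illustrative only, with a concrete element type for brevity):

```python
from typing import Iterator, Optional

class _Node:
    def __init__(self) -> None:
        self._it: Optional[Iterator[int]] = None

    @property
    def _active_it(self) -> Iterator[int]:
        # Narrow Optional[...] to a concrete Iterator in one place so
        # call sites don't need per-line ignores or asserts.
        if self._it is None:
            raise RuntimeError("iterator not initialized; call reset() first")
        return self._it

    def next(self) -> int:
        return next(self._active_it)
```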
divyanshk left a comment:
Let's gooo!!!
Linux builds failing are unrelated
* fix loader, add a unit test
* Weighted Multi dataset mixer (#1361)
  * initial commit
  * remove load_state_dict
  * barebones test for testing
  * testing
  * loop over a single dataset
  * updates behaviour
  * update docstring
  * add global wrapper node, need to think of better name
  * add to init
  * Rename to Loader
  * Rename to Loader
  * working example with PR #1358
  * working example with PR #1358
  * Increase state_dict test coverage, fix snapshotting bug in mappers
  * partial
  * convert generators to explicit iterators
  * convert multi dataset generator to iterator
  * remove print
  * temp
  * update pin_memory
  * update map
  * fix base_node tests
  * update docstring
  * split up diff into multiple
  * reset _num_yielded on every iter() call
  * sync on top of PR #1362
  * Add stopping criteria and basic validation
  * fix merge, remove extra prints
  * fix error in remote testing job
  * add unit test
  * Clean up
  * Address comments, add more tests, fix a bug in ALL_DATASETS_EXHAUSTED
  * Add seed logic for workers
  * Add rank based seed
  * unit test for rank level seeds
  * update naming in tests
  * address comments, get rid of recursive call inside next(...)
  * add doc strings, separate out stop criteria in its own file
  * remove the loader.py changes
  * fix lints, remove the loader.py changes
* update variable name and add comments
* Update unit tests
* Update unit tests
* Update unit tests
* rearrange unit test method

Co-authored-by: andrewkho <[email protected]>
What would it look like to disallow generators?
Arrived at this PR after a couple of iterations:
BaseNodes that have .iterator() implemented as generators break down once we start trying multi-dataset sampling.
Example:
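A minimal sketch of the kind of setup being described; the class name MultiDatasetMixer, the weights, and the dataset sizes are assumptions drawn from the description that follows:

```python
# Hypothetical reconstruction: class names and arguments are assumed.
node_0 = SomeNode(range(10))          # tiny dataset 0
node_1 = SomeNode(range(1_000_000))   # large dataset 1

# Sample dataset 0 with weight 0.01 and dataset 1 with weight 0.99;
# iteration ends when dataset 0 is exhausted.
mixer = MultiDatasetMixer([node_0, node_1], weights=[0.01, 0.99])

it = iter(mixer)
for _ in it:                          # run epoch 0 to completion
    pass
state_dict_0 = mixer.state_dict()

it = iter(mixer)                      # request a fresh iterator for epoch 1
state_dict_1 = mixer.state_dict()     # taken before the first next() call
```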
The behaviour of this node should be to mostly sample from dataset 1 (due to weight 0.99) and loop until we finally get 10 samples from dataset 0. Only once dataset 0 raises StopIteration will MultiDatasetMixer also raise StopIteration.
What does state_dict_0 contain? It should contain the state of datasets 0 and 1, where 0 will be at the end of its previous iteration.
Now what does state_dict_1 contain? It should contain the state of datasets 0 and 1, where 0 is almost certainly at or near the beginning of iteration. But it doesn't! It contains the state at the end of iteration. Loading this state_dict will return unexpected results.
Why is this the case? The problem is that Dataset 0's .iterator() method is defined by a generator. Even if you request another .iterator() from it, none of its initialization code runs until the first next() call occurs.
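This is plain Python behavior, easy to see in isolation:

```python
def iterator():
    print("initializing")  # setup users expect to run when the iterator is made
    yield from range(3)

it = iterator()   # nothing is printed: the generator body has not started
next(it)          # "initializing" is printed only now, on the first next()
```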
How can we deal with this and get the correct behaviour? I ran through a few scenarios and they all sound disgusting IMHO.
No Generators
Getting rid of generators has a major downside: they are a very Pythonic way of writing and expressing iterators, and we should be very wary of giving this up. Despite this massive downside, it's even more crucial for users to have a clear mental model of how Nodes should behave, and the fewer things there are to reason about, the easier this usually is.
Downsides of supporting generators:

* Any setup code before the first yield does not run until the first next() call, so requesting a fresh iterator does not actually reset anything (the state_dict problem above).
* A generator's state lives in its paused stack frame, so there is no natural place to capture it for state_dict() or restore it on load.
Iterators only
If we're going to disallow generators, we might as well get rid of the concept of Iterables as well, for the same reason as above (no multi-iterator support). Supporting both Iterables and Iterators gives us more surface area to cover.
The main downside of getting rid of Iterables is that Iterables give us a natural method of doing lazy initialization and resetting iteration/getting a new iterator. With Iterables, you can delay expensive initialization until you call __iter__; with Iterators, you'd need to either make next() lazy or have the framework include some other call (we opt for the latter). With Iterables, getting a fresh iterator is as simple as calling __iter__; if there is tricky clean-up or state to manage, it's easy to just throw away the old one and create a brand-new Iterator. An Iterator-only approach requires an explicit "reset" call or some other approach to handle multiple epochs.

We opt for the following API:
Implementations need to define these three methods.
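Judging from the signatures visible in the diff above (reset(initial_state) and next()), plus a state getter, a subclass might look like this sketch; the import path and the get_state name are assumptions:

```python
from typing import Any, Dict, Optional

from torchdata.nodes import BaseNode  # import path is an assumption

class RangeNode(BaseNode[int]):
    """Explicit-iterator node: no generators, state is always inspectable."""

    def __init__(self, n: int):
        super().__init__()
        self.n = n
        self._i = 0

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        super().reset(initial_state)  # required: marks the node initialized
        self._i = initial_state["i"] if initial_state is not None else 0

    def next(self) -> int:
        # Explicit replacement for a generator body: raise StopIteration
        # when exhausted, otherwise advance and return one item.
        if self._i >= self.n:
            raise StopIteration()
        item = self._i
        self._i += 1
        return item

    def get_state(self) -> Dict[str, Any]:
        # Everything needed to resume from exactly this point.
        return {"i": self._i}
```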