
Conversation

@andrewkho (Contributor) commented Nov 12, 2024

What would it look like to disallow generators?

Arrived at this PR after a couple of iterations:

BaseNodes that have .iterator() implemented as generators break down once we try multi-dataset sampling.

Example:

node = MultiDatasetMixer(
  datasets = {0: IterableWrapper(range(10)), 1: IterableWrapper(range(10))}, 
  weights = {0: 0.01, 1: 0.99},
  stop_criteria = "LOOP_UNTIL_ALL_ITERS_EXHAUSTED",
)
results = list(node)
state_dict_0 = node.state_dict()

# Start a new epoch
for idx, val in enumerate(node):
  if idx == 5:
    break

state_dict_1 = node.state_dict()

The behaviour of this node should be to sample mostly from dataset 1 (due to its 0.99 weight) and loop over dataset 0 until we finally get 10 samples from it. Only once dataset 0 raises StopIteration will MultiDatasetMixer also raise StopIteration.

What does state_dict_0 contain? It should contain the state of datasets 0 and 1, where 0 will be at the end of its previous iteration.

Now what does state_dict_1 contain? It should contain the state of datasets 0 and 1, where 0 is almost certainly at or near the beginning of iteration. But it doesn't! It contains the state at the end of iteration. Loading this state_dict will return unexpected results.

Why is this the case? The problem is that dataset 0's .iterator() method is defined by a generator. Even if you request another .iterator() from it, none of its initialization code runs until the first next() call occurs.
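A minimal standalone sketch (names hypothetical, not from this PR) of that deferred execution, showing that a generator's body does not run at creation time:

```python
# Initialization code inside a generator body is deferred: creating the
# generator object does nothing observable; the body only starts running
# on the first next() call.
init_calls = []

def make_iterator():
    init_calls.append("init ran")  # "initialization" code in the generator body
    yield from range(3)

it = make_iterator()          # generator object created...
assert init_calls == []       # ...but the body has not run yet
next(it)                      # only now does the init code execute
assert init_calls == ["init ran"]
```

This is exactly why asking dataset 0 for a fresh iterator does not reset any state you can observe before iteration begins.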

How can we deal with this and get the correct behaviour? I ran through a few scenarios and they all sound disgusting IMHO.

  • Change IterableWrapper.iterator() to return an Iterator instead of a Generator. This doesn't work because of composability: all BaseNode subclasses would have to implement this. The only way this could work is if we "coloured" the nodes, so that certain nodes were only compatible with other nodes, or more realistically, to disallow generators entirely.
  • Add a has_next() or other data caching mechanism. Major downside: you now need to hold a batch from every source in memory, as well as deal with state either by holding a copy of the previous state_dict, or by storing the data itself in the state_dict.
  • Ask generator implementations to handle re-initialization AFTER the loop finishes (before exiting .iterator()). But this is non-obvious, depends on users doing the right thing, and makes state management complicated to handle.

No Generators

Getting rid of generators has a major downside: they are a very Pythonic way of writing and expressing iterators, and we should be wary of giving this up. Despite that, it's even more crucial for users to have a clear mental model of how Nodes behave, and the fewer things there are to reason about, the easier that usually is.

Downsides of supporting generators:

  • Two implementation paths to reason about: we would need to support, and reason about, both Generators and Iterators.
  • Explicitly managing state: an Iterable that returns a Generator has its state managed implicitly by the runtime. Asking a generator for a state_dict() isn't really possible; instead we need to ask for state_dict() from the Iterable instance that defined the Generator, and users need to re-write their generators to keep state in instance-level variables.
  • No multi-iterator support: keeping state in the Iterable's instance variables breaks if someone ever creates multiple iterators from the same Iterable. This is an uncommon scenario in data loading and most folks would not do it, but it still forces framework maintainers to make assumptions and reason about these scenarios, and about how to catch/block them.
  • Harder to reason about state: see the example above; can we expect users and developers to reason correctly about the current state of the iterator?
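To make the "instance-level variables" point concrete, here is a hedged sketch (RangeIterator is illustrative, not from this PR) of what rewriting a generator as an explicit iterator looks like:

```python
# An explicit iterator keeps its position in instance variables, so its
# state can be inspected or serialized at any point mid-iteration. The
# equivalent generator hides that state inside a suspended stack frame.
class RangeIterator:
    def __init__(self, n: int):
        self.n = n
        self.i = 0  # explicit, inspectable state

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self.i >= self.n:
            raise StopIteration
        val = self.i
        self.i += 1
        return val

    def get_state(self) -> dict:
        return {"i": self.i}

it = RangeIterator(3)
next(it)
assert it.get_state() == {"i": 1}  # state is available mid-iteration
```

The cost is boilerplate; the benefit is that state_dict-style snapshotting becomes trivial and unambiguous.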

Iterators only

If we're going to disallow generators, we might as well get rid of the concept of Iterables as well, for the same reason as above (No multi-iterator support). Supporting both Iterables + Iterators gives us more surface area to cover:

  • How would you trace the DAG through inspection: from the Iterable or the Iterator?
  • From whom should you request the state, Iterable or Iterator?

The main downside of getting rid of Iterables is that Iterables give us a natural mechanism for lazy initialization, and for resetting iteration / getting a new iterator. With Iterables, you can delay expensive initialization until you call __iter__. With Iterators, you'd need to either make next() lazy, or have the framework add some other call (we opt for the latter). With Iterables, getting a fresh iterator is as simple as calling __iter__: if there is tricky cleanup or state to manage, it's easy to just throw away the old iterator and create a brand-new one. An Iterator-only approach requires an explicit "reset" call or some other mechanism to handle multiple epochs.

We opt for the following API:

class BaseNode(Iterator[T]):
  def next(self): ... 
  def reset(self, initial_state: Optional[dict] = None): ...
  def get_state(self) -> dict: ...

Implementations need to define these three methods.
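As a hedged sketch of what a concrete implementation might look like under this API (RangeNode and its state layout are illustrative, not part of the PR):

```python
from typing import Optional

# A toy node implementing the three-method API: reset() owns all
# (possibly expensive) initialization and can restore a prior state,
# next() produces items or raises StopIteration, and get_state()
# snapshots the explicit instance-level state.
class RangeNode:
    def __init__(self, n: int):
        self.n = n
        self.i = 0

    def reset(self, initial_state: Optional[dict] = None):
        self.i = initial_state["i"] if initial_state else 0

    def next(self) -> int:
        if self.i >= self.n:
            raise StopIteration
        val = self.i
        self.i += 1
        return val

    def get_state(self) -> dict:
        return {"i": self.i}

node = RangeNode(5)
node.reset()
first_two = [node.next(), node.next()]
state = node.get_state()          # capture mid-iteration
node.reset(initial_state=state)   # resume exactly where we left off
assert node.next() == 2
```

Because initialization lives in reset() rather than in a generator body, the MultiDatasetMixer scenario above behaves predictably: a new epoch means an explicit reset(), and state snapshots are always well-defined.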

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 12, 2024
@andrewkho andrewkho changed the base branch from main to andrewkh/add-global-wrapper-node November 12, 2024 06:36
@pytorch-bot

pytorch-bot bot commented Nov 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/data/1362

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 1fa022f with merge base 66a17ae:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@andrewkho
Contributor Author

Updating based on feedback / ideation in #1363

@andrewkho andrewkho marked this pull request as ready for review November 12, 2024 23:39
@andrewkho andrewkho changed the title Retool for iterator only [RFC] Retool for iterator only Nov 12, 2024
divyanshk added a commit that referenced this pull request Nov 13, 2024
@divyanshk (Contributor)

Comment on lines +58 to +59
def reset(self, initial_state: Optional[dict] = None):
    self.__initialized = True

We don't force users to define reset?

@andrewkho (Contributor Author) replied:

We do, in a bit of a roundabout way: if users don't implement .reset() (and call super().reset()), we throw an error during next() and state_dict() calls
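That enforcement might look roughly like the following sketch (names and error message are illustrative, not the PR's exact code):

```python
from typing import Optional

# The base class flips a private flag in reset(); next() checks it and
# raises if a subclass overrode reset() without calling super().reset().
class BaseNode:
    def __init__(self, *args, **kwargs):
        self.__initialized = False

    def reset(self, initial_state: Optional[dict] = None):
        self.__initialized = True

    def _check_initialized(self):
        if not self.__initialized:
            raise RuntimeError(
                "reset() was not called (did you forget super().reset()?)"
            )

    def next(self):
        self._check_initialized()
        ...

class BadNode(BaseNode):
    def reset(self, initial_state=None):
        pass  # forgot super().reset()

node = BadNode()
node.reset()
try:
    node.next()
except RuntimeError as e:
    print(f"caught: {e}")
```

Name mangling keeps `__initialized` private to the base class, so subclasses can't accidentally satisfy the check without going through super().reset().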

Comment on lines +52 to +53
def __init__(self, *args, **kwargs):
    self.__initialized = False

Trying to understand the role of self.__initialized: when we initialize a node (and call super().__init__() in it), we set __initialized to False, but only when we reset do we set it to True.

@andrewkho (Contributor Author) replied:

This is here to enable .reset() to contain initialization logic which users might want to be lazily run, because it's expensive and they may e.g. want to load a state_dict before beginning iteration.

if self._it is not None:
    self._it._shutdown()
    del self._it
self._it = _SingleThreadedMapper(
@andrewkho (Contributor Author):

TODO: we should try and make single threaded mapper and ParallelMapper handle reset, similar to how persistent_workers currently works. However that's going to take a more significant refactor so leaving like this for now to minimize the changes

self._num_yielded += 1
yield item

def next(self) -> T:
    item = next(self._it)  # type: ignore [arg-type, union-attr]
@andrewkho (Contributor Author), Nov 13, 2024:

There might be a more clever way to handle this, e.g. through a property or something, that won't make the type-hinters complain. I'd prefer to do it in a separate diff as it's currently mostly sugar. Alternatively we could assert self._it is not None to keep linters happy, but I don't like the idea of adding code just to satisfy linters.

@divyanshk (Contributor) left a comment:

Lets gooo!!!

Base automatically changed from andrewkh/add-global-wrapper-node to main November 13, 2024 19:47
@andrewkho
Copy link
Contributor Author

Linux builds failing are unrelated

@andrewkho andrewkho merged commit 7fa74e4 into main Nov 13, 2024
@andrewkho andrewkho deleted the andrewkh/retool-for-iterator-only branch November 13, 2024 21:03
divyanshk added a commit that referenced this pull request Nov 17, 2024
* inital commit

* remove load_state_dict

* barebones test for testing

* testing

* loop over a single dataset

* updates behaviour

* update docstring

* add global wrapper node, need to think of better name

* add to init

* Rename to Loader

* Rename to Loader

* working example with PR#1358

* working example with PR#1358

* Increase state_dict test coverage, fix snapshotting bug in mappers

* partial

* convert generators to explicit iterators

* convert multi dataset generator to iterator

* remove print

* temp

* update pin_memory

* update map

* fix base_node tests

* update docstring

* split up diff into multiple

* reset _num_yielded on every iter() call

* sync on top of PR #1362

* Add stopping criteria and basic validation

* fix merge, remove extra prints

* fix error in remote testing job

* add unit test

* Clean up

* Address comments, add more tests, fix a bug in ALL_DATASETS_EXHAUSTED

* Add seed logic for workers

* Add rank based seed

* unit test for rank level seeds

* update naming in tests

* address comments, get rid of recursive call inside next(...)

* add doc strings, separate out stop critera in its own file

* remove the loader.py changes

* fix lints, remove the loader.py changes

---------

Co-authored-by: andrewkho <[email protected]>
divyanshk added a commit that referenced this pull request Nov 21, 2024
divyanshk added a commit that referenced this pull request Nov 26, 2024
* fix loader, add a unit test

* Weighted Multi dataset mixer (#1361)


* update variable name and add comments

* Update unit tests

* Update unit tests

* Update unit tests

* rearrage unit test method

---------

Co-authored-by: andrewkho <[email protected]>