-
Notifications
You must be signed in to change notification settings - Fork 7.7k
[Data] Remove _base_dataset from StreamSplitDataIterator #61607
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
584b14c
f83c364
993cc04
eaadbf1
b94271f
4d54320
d31cd12
be1dc9f
6c70a95
e407674
017c539
41570d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,18 +49,14 @@ def create( | |
| ), | ||
| ).remote(base_dataset, n, locality_hints) | ||
|
|
||
| return [ | ||
| StreamSplitDataIterator(base_dataset, coord_actor, i, n) for i in range(n) | ||
| ] | ||
| return [StreamSplitDataIterator(coord_actor, i, n) for i in range(n)] | ||
|
|
||
| def __init__( | ||
| self, | ||
| base_dataset: "Dataset", | ||
| coord_actor: ray.actor.ActorHandle, | ||
| output_split_idx: int, | ||
| world_size: int, | ||
| ): | ||
| self._base_dataset = base_dataset | ||
| self._coord_actor = coord_actor | ||
| self._output_split_idx = output_split_idx | ||
| self._world_size = world_size | ||
|
|
@@ -100,7 +96,6 @@ def gen_blocks() -> Iterator[RefBundle]: | |
| schema=block_ref_and_md.schema, | ||
| ) | ||
|
|
||
| self._base_dataset._plan._run_index += 1 | ||
| # Return None for executor since StreamSplitDataIterator has its own | ||
| # mechanism for reporting prefetched bytes via SplitCoordinator. | ||
| return gen_blocks(), self._iter_stats, False, None | ||
|
|
@@ -119,17 +114,17 @@ def stats(self) -> str: | |
|
|
||
| def schema(self) -> Union[type, "pyarrow.lib.Schema"]: | ||
|
JasonLi1909 marked this conversation as resolved.
Outdated
|
||
| """Implements DataIterator.""" | ||
| return self._base_dataset.schema() | ||
| return ray.get(self._coord_actor.get_dataset_schema.remote()) | ||
|
JasonLi1909 marked this conversation as resolved.
Outdated
cursor[bot] marked this conversation as resolved.
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| def get_context(self) -> DataContext: | ||
| return self._base_dataset.context | ||
| return ray.get(self._coord_actor.get_dataset_context.remote()) | ||
|
|
||
| def world_size(self) -> int: | ||
| """Returns the number of splits total.""" | ||
| return self._world_size | ||
|
|
||
| def _get_dataset_tag(self): | ||
| return f"{self._base_dataset.get_dataset_id()}_split_{self._output_split_idx}" | ||
| return ray.get(self._coord_actor.get_dataset_tag.remote(self._output_split_idx)) | ||
|
|
||
|
|
||
| @ray.remote(num_cpus=0) | ||
|
|
@@ -157,6 +152,8 @@ def __init__( | |
| self._n = n | ||
| self._locality_hints = locality_hints | ||
| self._lock = threading.RLock() | ||
| self._dataset_state_lock = threading.Lock() | ||
| self._schema = None | ||
| self._current_executor = None | ||
|
|
||
| # Guarded by self._lock. | ||
|
|
@@ -175,6 +172,24 @@ def __init__( | |
| # Store the error raised from the `gen_epoch` call. | ||
| self._gen_epoch_error: Optional[Exception] = None | ||
|
|
||
| def get_dataset_context(self) -> DataContext: | ||
| return self._data_context | ||
|
|
||
| def get_dataset_tag(self, output_split_idx: int) -> str: | ||
| return f"{self._base_dataset.get_dataset_id()}_split_{output_split_idx}" | ||
|
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| def get_dataset_schema(self): | ||
| with self._dataset_state_lock: | ||
| if self._current_executor is not None and self._current_executor.is_alive(): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I read the PR description but still a bit confused -- Are you implying there are scenarios where the schema is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| raise RuntimeError( | ||
| "Cannot call schema() during active dataset execution. " | ||
|
Comment on lines
+185
to
+186
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you try running this without raising the error to see what happens? and add the output to the PR description appendix
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ran without the safe guard, it can lead to a hang because the second executor will try to schedule tasks but won't have any resources available, added to PR description.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for checking |
||
| "Call schema() before or after iterating over the dataset, or call " | ||
| "schema() directly on the source Dataset object." | ||
| ) | ||
| if self._schema is None: | ||
| self._schema = self._base_dataset.schema() | ||
|
JasonLi1909 marked this conversation as resolved.
Outdated
|
||
| return self._schema | ||
|
cursor[bot] marked this conversation as resolved.
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| def stats(self) -> DatasetStats: | ||
| """Returns stats from the base dataset.""" | ||
| if self._current_executor: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to confirm -- this
StreamSplitDataIteratorwas being serialized 2x?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, each worker only gets one iterator so it's only serialized once. The other copy of the dataset object came from the TrainRunContext which was removed in #61953