-
Notifications
You must be signed in to change notification settings - Fork 7.3k
[Data] Revisit OutputSplitter semantic to avoid unnecessary buffer accumulation #60237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| from collections import deque | ||
| from typing import Any, Collection, Dict, List, Optional, Tuple | ||
|
|
||
| from ray._private.ray_constants import env_float | ||
| from ray.data._internal.execution.bundle_queue import ( | ||
| HashLinkedQueue, | ||
| ) | ||
|
|
@@ -22,6 +23,10 @@ | |
| from ray.data.context import DataContext | ||
| from ray.types import ObjectRef | ||
|
|
||
| DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR = env_float( | ||
| "RAY_DATA_DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR", 2 | ||
| ) | ||
|
|
||
|
|
||
| class OutputSplitter(InternalQueueOperatorMixin, PhysicalOperator): | ||
| """An operator that splits the given data into `n` output splits. | ||
|
|
@@ -68,14 +73,21 @@ def __init__( | |
| f"len({locality_hints}) != {n}" | ||
| ) | ||
| self._locality_hints = locality_hints | ||
|
|
||
| # To optimize locality, we might defer dispatching of the bundles to allow | ||
| # for better node affinity by allowing next receiver to wait for a block | ||
| # with preferred locality (minimizing data movement). | ||
| # | ||
| # However, to guarantee liveness we cap buffering to not exceed | ||
| # | ||
| # DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR * N | ||
| # | ||
| # Where N is the number of outputs the sequence is being split into | ||
| if locality_hints: | ||
| # To optimize locality, we should buffer a certain number of elements | ||
| # internally before dispatch to allow the locality algorithm a good chance | ||
| # of selecting a preferred location. We use a small multiple of `n` since | ||
| # it's reasonable to buffer a couple blocks per consumer. | ||
| self._min_buffer_size = 2 * n | ||
| self._max_buffer_size = DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR * n | ||
| else: | ||
| self._min_buffer_size = 0 | ||
| self._max_buffer_size = 0 | ||
|
|
||
| self._locality_hits = 0 | ||
| self._locality_misses = 0 | ||
|
|
||
|
|
@@ -92,7 +104,7 @@ def start(self, options: ExecutionOptions) -> None: | |
| if options.preserve_order: | ||
| # If preserve_order is set, we need to ignore locality hints to ensure determinism. | ||
| self._locality_hints = None | ||
| self._min_buffer_size = 0 | ||
| self._max_buffer_size = 0 | ||
|
|
||
| super().start(options) | ||
|
|
||
|
|
@@ -128,13 +140,24 @@ def _add_input_inner(self, bundle, input_index) -> None: | |
| raise ValueError("OutputSplitter requires bundles with known row count") | ||
| self._buffer.add(bundle) | ||
| self._metrics.on_input_queued(bundle) | ||
| self._dispatch_bundles() | ||
| # Try dispatch buffered bundles | ||
| self._try_dispatch_bundles() | ||
|
|
||
| def all_inputs_done(self) -> None: | ||
| super().all_inputs_done() | ||
| if not self._equal: | ||
| self._dispatch_bundles(dispatch_all=True) | ||
| assert not self._buffer, "Should have dispatched all bundles." | ||
|
|
||
| # First, attempt to dispatch bundles based on the locality preferences | ||
| # (if configured) | ||
| if self._locality_hints: | ||
| # NOTE: If equal distribution is not requested, we will force | ||
| # the dispatching | ||
| self._try_dispatch_bundles(force=not self._equal) | ||
|
|
||
| if not self._equal: | ||
| assert not self._buffer, "All bundles should have been dispatched" | ||
| return | ||
|
|
||
| if not self._buffer: | ||
| return | ||
|
|
||
| # Otherwise: | ||
|
|
@@ -191,57 +214,93 @@ def progress_str(self) -> str: | |
| else: | ||
| return "[locality disabled]" | ||
|
|
||
| def _dispatch_bundles(self, dispatch_all: bool = False) -> None: | ||
| def _try_dispatch_bundles(self, force: bool = False) -> None: | ||
| start_time = time.perf_counter() | ||
| # Dispatch all dispatchable bundles from the internal buffer. | ||
| # This may not dispatch all bundles when equal=True. | ||
| while self._buffer and ( | ||
| dispatch_all or len(self._buffer) >= self._min_buffer_size | ||
| ): | ||
| target_index = self._select_output_index() | ||
| target_bundle = self._peek_bundle_to_dispatch(target_index) | ||
| if self._can_safely_dispatch(target_index, target_bundle.num_rows()): | ||
| target_bundle = self._buffer.remove(target_bundle) | ||
| self._metrics.on_input_dequeued(target_bundle) | ||
| target_bundle.output_split_idx = target_index | ||
| self._num_output[target_index] += target_bundle.num_rows() | ||
| self._output_queue.append(target_bundle) | ||
| self._metrics.on_output_queued(target_bundle) | ||
| if self._locality_hints: | ||
| preferred_loc = self._locality_hints[target_index] | ||
| if preferred_loc in self._get_locations(target_bundle): | ||
| self._locality_hits += 1 | ||
| else: | ||
| self._locality_misses += 1 | ||
|
|
||
| # Currently, there are 2 modes of operation when dispatching | ||
| # accumulated bundles: | ||
| # | ||
| # 1. Best-effort: we do a single pass over the whole buffer | ||
| # and try to dispatch all bundles either | ||
| # | ||
| # a) Based on their locality (if feasible) | ||
| # b) Longest-waiting if buffer exceeds max-size threshold | ||
| # | ||
| # 2. Mandatory: when whole buffer has to be dispatched (for ex, | ||
| # upon completion of the dataset execution) | ||
| # | ||
| for _ in range(len(self._buffer)): | ||
| # Get target output index of the next receiver | ||
| target_output_index = self._select_next_output_index() | ||
| # Look up preferred bundle | ||
| preferred_bundle = self._find_preferred_bundle(target_output_index) | ||
|
|
||
| if preferred_bundle: | ||
| target_bundle = preferred_bundle | ||
| elif len(self._buffer) >= self._max_buffer_size or force: | ||
| # If we're not able to find a preferred bundle and buffer size is above | ||
| # the cap, we pop the longest awaiting and pass to the next receiver | ||
| target_bundle = self._buffer.peek_next() | ||
| assert target_bundle is not None | ||
| else: | ||
| # Abort. | ||
| # Provided that we weren't able to either locate preferred bundle | ||
| # or dequeue the head one, we bail out from iteration | ||
| break | ||
|
|
||
| # In case, when we can't safely dispatch (to avoid violating distribution | ||
| # requirements), short-circuit | ||
| if not self._can_safely_dispatch( | ||
| target_output_index, target_bundle.num_rows() | ||
| ): | ||
| break | ||
|
|
||
| # Pop preferred bundle from the buffer | ||
| self._buffer.remove(target_bundle) | ||
| self._metrics.on_input_dequeued(target_bundle) | ||
|
|
||
| target_bundle.output_split_idx = target_output_index | ||
|
|
||
| self._num_output[target_output_index] += target_bundle.num_rows() | ||
| self._output_queue.append(target_bundle) | ||
| self._metrics.on_output_queued(target_bundle) | ||
|
|
||
| if self._locality_hints: | ||
| if preferred_bundle: | ||
| self._locality_hits += 1 | ||
| else: | ||
| self._locality_misses += 1 | ||
|
|
||
| self._output_splitter_overhead_time += time.perf_counter() - start_time | ||
|
|
||
| def _select_output_index(self) -> int: | ||
| def _select_next_output_index(self) -> int: | ||
| # Greedily dispatch to the consumer with the least data so far. | ||
| i, _ = min(enumerate(self._num_output), key=lambda t: t[1]) | ||
| return i | ||
|
|
||
| def _peek_bundle_to_dispatch(self, target_index: int) -> RefBundle: | ||
| def _find_preferred_bundle(self, target_output_index: int) -> Optional[RefBundle]: | ||
| if self._locality_hints: | ||
| preferred_loc = self._locality_hints[target_index] | ||
| preferred_loc = self._locality_hints[target_output_index] | ||
|
|
||
| # TODO make this more efficient (adding inverse hash-map) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Consider implementing the suggestion in the |
||
| for bundle in self._buffer: | ||
| if preferred_loc in self._get_locations(bundle): | ||
| return bundle | ||
|
|
||
| return self._buffer.peek_next() | ||
| return None | ||
|
|
||
| def _can_safely_dispatch(self, target_index: int, nrow: int) -> bool: | ||
| def _can_safely_dispatch(self, target_index: int, target_num_rows: int) -> bool: | ||
| if not self._equal: | ||
| # If not in equals mode, dispatch away with no buffer requirements. | ||
| return True | ||
|
|
||
| # Simulate dispatching a bundle to the target receiver | ||
| output_distribution = self._num_output.copy() | ||
| output_distribution[target_index] += nrow | ||
| output_distribution[target_index] += target_num_rows | ||
| buffer_requirement = self._calculate_buffer_requirement(output_distribution) | ||
| buffer_size = self._buffer.num_rows() | ||
| # Subtract target bundle size from the projected buffer | ||
| buffer_size = self._buffer.num_rows() - target_num_rows | ||
| # Check if we have enough rows LEFT after dispatching to equalize. | ||
| return buffer_size - nrow >= buffer_requirement | ||
| return buffer_size >= buffer_requirement | ||
|
|
||
| def _calculate_buffer_requirement(self, output_distribution: List[int]) -> int: | ||
| # Calculate the new number of rows that we'd need to equalize the row | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing forced dispatch when equal=False without locality hints
Medium Severity
When
`equal=False` and `_locality_hints` is falsy, the new `all_inputs_done` logic skips the explicit forced dispatch and falls through to the finalize-distribution code if the buffer is non-empty. The finalize-distribution code is designed only for `equal=True` mode and assumes the output distribution needs equalization. Running it with `equal=False` can cause the `assert remainder >= 0` assertion to fail, since greedy dispatching produces uneven distributions where `sum(allocation)` may exceed `buffer_size`.