Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 100 additions & 41 deletions python/ray/data/_internal/execution/operators/output_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ray.data._internal.execution.bundle_queue import (
HashLinkedQueue,
)
from ray._private.ray_constants import env_float
from ray.data._internal.execution.interfaces import (
ExecutionOptions,
NodeIdStr,
Expand All @@ -22,6 +23,10 @@
from ray.data.context import DataContext
from ray.types import ObjectRef

DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR = env_float(
"RAY_DATA_DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR", 2
)


class OutputSplitter(InternalQueueOperatorMixin, PhysicalOperator):
"""An operator that splits the given data into `n` output splits.
Expand Down Expand Up @@ -68,14 +73,21 @@ def __init__(
f"len({locality_hints}) != {n}"
)
self._locality_hints = locality_hints

# To optimize locality, we might defer dispatching of the bundles to allow
# for better node affinity by allowing next receiver to wait for a block
# with preferred locality (minimizing data movement).
#
# However, to guarantee liveness we cap buffering to not exceed
#
# DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR * N
#
# Where N is the number of outputs the sequence is being split into
if locality_hints:
# To optimize locality, we should buffer a certain number of elements
# internally before dispatch to allow the locality algorithm a good chance
# of selecting a preferred location. We use a small multiple of `n` since
# it's reasonable to buffer a couple blocks per consumer.
self._min_buffer_size = 2 * n
self._max_buffer_size = DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR * n
else:
self._min_buffer_size = 0
self._max_buffer_size = 0

self._locality_hits = 0
self._locality_misses = 0

Expand All @@ -92,7 +104,7 @@ def start(self, options: ExecutionOptions) -> None:
if options.preserve_order:
# If preserve_order is set, we need to ignore locality hints to ensure determinism.
self._locality_hints = None
self._min_buffer_size = 0
self._max_buffer_size = 0

super().start(options)

Expand Down Expand Up @@ -128,13 +140,24 @@ def _add_input_inner(self, bundle, input_index) -> None:
raise ValueError("OutputSplitter requires bundles with known row count")
self._buffer.add(bundle)
self._metrics.on_input_queued(bundle)
self._dispatch_bundles()
# Try dispatch buffered bundles
self._try_dispatch_bundles()

def all_inputs_done(self) -> None:
super().all_inputs_done()
if not self._equal:
self._dispatch_bundles(dispatch_all=True)
assert not self._buffer, "Should have dispatched all bundles."

# First, attempt to dispatch bundles based on the locality preferences
# (if configured)
if self._locality_hints:
# NOTE: If equal distribution is not requested, we will force
# the dispatching
self._try_dispatch_bundles(force=not self._equal)

if not self._equal:
assert not self._buffer, "All bundles should have been dispatched"
return

if not self._buffer:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing forced dispatch when equal=False without locality hints

Medium Severity

When equal=False and _locality_hints is falsy, the new all_inputs_done logic skips the explicit forced dispatch and falls through to the finalize-distribution code if the buffer is non-empty. The finalize-distribution code is designed only for equal=True mode and assumes the output distribution needs equalization. Running it with equal=False can cause the `assert remainder >= 0` assertion to fail, since greedy dispatching produces uneven distributions where sum(allocation) may exceed buffer_size.

Fix in Cursor Fix in Web

return

# Otherwise:
Expand Down Expand Up @@ -191,57 +214,93 @@ def progress_str(self) -> str:
else:
return "[locality disabled]"

def _dispatch_bundles(self, dispatch_all: bool = False) -> None:
def _try_dispatch_bundles(self, force: bool = False) -> None:
start_time = time.perf_counter()
# Dispatch all dispatchable bundles from the internal buffer.
# This may not dispatch all bundles when equal=True.
while self._buffer and (
dispatch_all or len(self._buffer) >= self._min_buffer_size
):
target_index = self._select_output_index()
target_bundle = self._peek_bundle_to_dispatch(target_index)
if self._can_safely_dispatch(target_index, target_bundle.num_rows()):
target_bundle = self._buffer.remove(target_bundle)
self._metrics.on_input_dequeued(target_bundle)
target_bundle.output_split_idx = target_index
self._num_output[target_index] += target_bundle.num_rows()
self._output_queue.append(target_bundle)
self._metrics.on_output_queued(target_bundle)
if self._locality_hints:
preferred_loc = self._locality_hints[target_index]
if preferred_loc in self._get_locations(target_bundle):
self._locality_hits += 1
else:
self._locality_misses += 1

# Currently, there are 2 modes of operation when dispatching
# accumulated bundles:
#
# 1. Best-effort: we do a single pass over the whole buffer
# and try to dispatch all bundles either
#
# a) Based on their locality (if feasible)
# b) Longest-waiting if buffer exceeds max-size threshold
#
# 2. Mandatory: when whole buffer has to be dispatched (for ex,
# upon completion of the dataset execution)
#
for _ in range(len(self._buffer)):
# Get target output index of the next receiver
target_output_index = self._select_next_output_index()
# Look up preferred bundle
preferred_bundle = self._find_preferred_bundle(target_output_index)

if preferred_bundle:
target_bundle = preferred_bundle
elif len(self._buffer) >= self._max_buffer_size or force:
# If we're not able to find a preferred bundle and buffer size is above
# the cap, we pop the longest awaiting and pass to the next receiver
target_bundle = self._buffer.peek_next()
assert target_bundle is not None
else:
# Abort.
# Provided that we weren't able to either locate preferred bundle
# or dequeue the head one, we bail out from iteration
break

# In case, when we can't safely dispatch (to avoid violating distribution
# requirements), short-circuit
if not self._can_safely_dispatch(
target_output_index, target_bundle.num_rows()
):
break

# Pop preferred bundle from the buffer
self._buffer.remove(target_bundle)
self._metrics.on_input_dequeued(target_bundle)

target_bundle.output_split_idx = target_output_index

self._num_output[target_output_index] += target_bundle.num_rows()
self._output_queue.append(target_bundle)
self._metrics.on_output_queued(target_bundle)

if self._locality_hints:
if preferred_bundle:
self._locality_hits += 1
else:
self._locality_misses += 1

self._output_splitter_overhead_time += time.perf_counter() - start_time

def _select_output_index(self) -> int:
def _select_next_output_index(self) -> int:
# Greedily dispatch to the consumer with the least data so far.
i, _ = min(enumerate(self._num_output), key=lambda t: t[1])
return i

def _peek_bundle_to_dispatch(self, target_index: int) -> RefBundle:
def _find_preferred_bundle(self, target_output_index: int) -> Optional[RefBundle]:
if self._locality_hints:
preferred_loc = self._locality_hints[target_index]
preferred_loc = self._locality_hints[target_output_index]

# TODO make this more efficient (adding inverse hash-map)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The TODO here correctly identifies a potential performance bottleneck. The current implementation iterates through the entire buffer (O(N)) to find a bundle with preferred locality. If the buffer size becomes large, this linear scan could impact performance.

Consider implementing the suggestion in the TODO by using an inverted index (e.g., a dictionary mapping location -> List[RefBundle]) to achieve O(1) lookups for preferred bundles. This would make the locality optimization much more efficient.

for bundle in self._buffer:
if preferred_loc in self._get_locations(bundle):
return bundle

return self._buffer.peek_next()
return None

def _can_safely_dispatch(self, target_index: int, nrow: int) -> bool:
def _can_safely_dispatch(self, target_index: int, target_num_rows: int) -> bool:
if not self._equal:
# If not in equals mode, dispatch away with no buffer requirements.
return True

# Simulate dispatching a bundle to the target receiver
output_distribution = self._num_output.copy()
output_distribution[target_index] += nrow
output_distribution[target_index] += target_num_rows
buffer_requirement = self._calculate_buffer_requirement(output_distribution)
buffer_size = self._buffer.num_rows()
# Subtract target bundle size from the projected buffer
buffer_size = self._buffer.num_rows() - target_num_rows
# Check if we have enough rows LEFT after dispatching to equalize.
return buffer_size - nrow >= buffer_requirement
return buffer_size >= buffer_requirement

def _calculate_buffer_requirement(self, output_distribution: List[int]) -> int:
# Calculate the new number of rows that we'd need to equalize the row
Expand Down
Loading