Skip to content

Commit 786d388

Browse files
iamjustinhsu and alexeykudinkin
authored and committed
[Data] Revisit OutputSplitter semantic to avoid unnecessary buffer accumulation (ray-project#60237)
## Description Currently, `OutputSplitter` is only dispatching blocks that exceeds it's baseline of N * 2 (where N is the number of workers) blocks. That doesn't make a lot of sense. This change instead inverses that semantic to - Dispatch blocks to the next outstanding receiver as soon as these become available - Force dispatch in case buffer exceed it's max-size threshold (enforce buffer doesn't exceed it's max-size) ## Related issues ## Additional information --------- Signed-off-by: iamjustinhsu <jhsu@anyscale.com> Co-authored-by: Alexey Kudinkin <ak@anyscale.com> Signed-off-by: Limark Dcunha <limarkdcunha@gmail.com>
1 parent 8caa9ea commit 786d388

File tree

1 file changed

+100
-41
lines changed

1 file changed

+100
-41
lines changed

python/ray/data/_internal/execution/operators/output_splitter.py

Lines changed: 100 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import deque
44
from typing import Any, Collection, Dict, List, Optional, Tuple
55

6+
from ray._private.ray_constants import env_float
67
from ray.data._internal.execution.bundle_queue import (
78
HashLinkedQueue,
89
)
@@ -22,6 +23,10 @@
2223
from ray.data.context import DataContext
2324
from ray.types import ObjectRef
2425

26+
DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR = env_float(
27+
"RAY_DATA_DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR", 2
28+
)
29+
2530

2631
class OutputSplitter(InternalQueueOperatorMixin, PhysicalOperator):
2732
"""An operator that splits the given data into `n` output splits.
@@ -68,14 +73,21 @@ def __init__(
6873
f"len({locality_hints}) != {n}"
6974
)
7075
self._locality_hints = locality_hints
76+
77+
# To optimize locality, we might defer dispatching of the bundles to allow
78+
# for better node affinity by allowing next receiver to wait for a block
79+
# with preferred locality (minimizing data movement).
80+
#
81+
# However, to guarantee liveness we cap buffering to not exceed
82+
#
83+
# DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR * N
84+
#
85+
# Where N is the number of outputs the sequence is being split into
7186
if locality_hints:
72-
# To optimize locality, we should buffer a certain number of elements
73-
# internally before dispatch to allow the locality algorithm a good chance
74-
# of selecting a preferred location. We use a small multiple of `n` since
75-
# it's reasonable to buffer a couple blocks per consumer.
76-
self._min_buffer_size = 2 * n
87+
self._max_buffer_size = DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR * n
7788
else:
78-
self._min_buffer_size = 0
89+
self._max_buffer_size = 0
90+
7991
self._locality_hits = 0
8092
self._locality_misses = 0
8193

@@ -92,7 +104,7 @@ def start(self, options: ExecutionOptions) -> None:
92104
if options.preserve_order:
93105
# If preserve_order is set, we need to ignore locality hints to ensure determinism.
94106
self._locality_hints = None
95-
self._min_buffer_size = 0
107+
self._max_buffer_size = 0
96108

97109
super().start(options)
98110

@@ -128,13 +140,24 @@ def _add_input_inner(self, bundle, input_index) -> None:
128140
raise ValueError("OutputSplitter requires bundles with known row count")
129141
self._buffer.add(bundle)
130142
self._metrics.on_input_queued(bundle)
131-
self._dispatch_bundles()
143+
# Try dispatch buffered bundles
144+
self._try_dispatch_bundles()
132145

133146
def all_inputs_done(self) -> None:
134147
super().all_inputs_done()
135-
if not self._equal:
136-
self._dispatch_bundles(dispatch_all=True)
137-
assert not self._buffer, "Should have dispatched all bundles."
148+
149+
# First, attempt to dispatch bundles based on the locality preferences
150+
# (if configured)
151+
if self._locality_hints:
152+
# NOTE: If equal distribution is not requested, we will force
153+
# the dispatching
154+
self._try_dispatch_bundles(force=not self._equal)
155+
156+
if not self._equal:
157+
assert not self._buffer, "All bundles should have been dispatched"
158+
return
159+
160+
if not self._buffer:
138161
return
139162

140163
# Otherwise:
@@ -191,57 +214,93 @@ def progress_str(self) -> str:
191214
else:
192215
return "[locality disabled]"
193216

194-
def _dispatch_bundles(self, dispatch_all: bool = False) -> None:
217+
def _try_dispatch_bundles(self, force: bool = False) -> None:
195218
start_time = time.perf_counter()
196-
# Dispatch all dispatchable bundles from the internal buffer.
197-
# This may not dispatch all bundles when equal=True.
198-
while self._buffer and (
199-
dispatch_all or len(self._buffer) >= self._min_buffer_size
200-
):
201-
target_index = self._select_output_index()
202-
target_bundle = self._peek_bundle_to_dispatch(target_index)
203-
if self._can_safely_dispatch(target_index, target_bundle.num_rows()):
204-
target_bundle = self._buffer.remove(target_bundle)
205-
self._metrics.on_input_dequeued(target_bundle)
206-
target_bundle.output_split_idx = target_index
207-
self._num_output[target_index] += target_bundle.num_rows()
208-
self._output_queue.append(target_bundle)
209-
self._metrics.on_output_queued(target_bundle)
210-
if self._locality_hints:
211-
preferred_loc = self._locality_hints[target_index]
212-
if preferred_loc in self._get_locations(target_bundle):
213-
self._locality_hits += 1
214-
else:
215-
self._locality_misses += 1
219+
220+
# Currently, there are 2 modes of operation when dispatching
221+
# accumulated bundles:
222+
#
223+
# 1. Best-effort: we do a single pass over the whole buffer
224+
# and try to dispatch all bundles either
225+
#
226+
# a) Based on their locality (if feasible)
227+
# b) Longest-waiting if buffer exceeds max-size threshold
228+
#
229+
# 2. Mandatory: when whole buffer has to be dispatched (for ex,
230+
# upon completion of the dataset execution)
231+
#
232+
for _ in range(len(self._buffer)):
233+
# Get target output index of the next receiver
234+
target_output_index = self._select_next_output_index()
235+
# Look up preferred bundle
236+
preferred_bundle = self._find_preferred_bundle(target_output_index)
237+
238+
if preferred_bundle:
239+
target_bundle = preferred_bundle
240+
elif len(self._buffer) >= self._max_buffer_size or force:
241+
# If we're not able to find a preferred bundle and buffer size is above
242+
# the cap, we pop the longest awaiting and pass to the next receiver
243+
target_bundle = self._buffer.peek_next()
244+
assert target_bundle is not None
216245
else:
217-
# Abort.
246+
# Provided that we weren't able to either locate preferred bundle
247+
# or dequeue the head one, we bail out from iteration
218248
break
249+
250+
# In case, when we can't safely dispatch (to avoid violating distribution
251+
# requirements), short-circuit
252+
if not self._can_safely_dispatch(
253+
target_output_index, target_bundle.num_rows()
254+
):
255+
break
256+
257+
# Pop preferred bundle from the buffer
258+
self._buffer.remove(target_bundle)
259+
self._metrics.on_input_dequeued(target_bundle)
260+
261+
target_bundle.output_split_idx = target_output_index
262+
263+
self._num_output[target_output_index] += target_bundle.num_rows()
264+
self._output_queue.append(target_bundle)
265+
self._metrics.on_output_queued(target_bundle)
266+
267+
if self._locality_hints:
268+
if preferred_bundle:
269+
self._locality_hits += 1
270+
else:
271+
self._locality_misses += 1
272+
219273
self._output_splitter_overhead_time += time.perf_counter() - start_time
220274

221-
def _select_output_index(self) -> int:
275+
def _select_next_output_index(self) -> int:
222276
# Greedily dispatch to the consumer with the least data so far.
223277
i, _ = min(enumerate(self._num_output), key=lambda t: t[1])
224278
return i
225279

226-
def _peek_bundle_to_dispatch(self, target_index: int) -> RefBundle:
280+
def _find_preferred_bundle(self, target_output_index: int) -> Optional[RefBundle]:
227281
if self._locality_hints:
228-
preferred_loc = self._locality_hints[target_index]
282+
preferred_loc = self._locality_hints[target_output_index]
283+
284+
# TODO make this more efficient (adding inverse hash-map)
229285
for bundle in self._buffer:
230286
if preferred_loc in self._get_locations(bundle):
231287
return bundle
232288

233-
return self._buffer.peek_next()
289+
return None
234290

235-
def _can_safely_dispatch(self, target_index: int, nrow: int) -> bool:
291+
def _can_safely_dispatch(self, target_index: int, target_num_rows: int) -> bool:
236292
if not self._equal:
237293
# If not in equals mode, dispatch away with no buffer requirements.
238294
return True
295+
296+
# Simulate dispatching a bundle to the target receiver
239297
output_distribution = self._num_output.copy()
240-
output_distribution[target_index] += nrow
298+
output_distribution[target_index] += target_num_rows
241299
buffer_requirement = self._calculate_buffer_requirement(output_distribution)
242-
buffer_size = self._buffer.num_rows()
300+
# Subtract target bundle size from the projected buffer
301+
buffer_size = self._buffer.num_rows() - target_num_rows
243302
# Check if we have enough rows LEFT after dispatching to equalize.
244-
return buffer_size - nrow >= buffer_requirement
303+
return buffer_size >= buffer_requirement
245304

246305
def _calculate_buffer_requirement(self, output_distribution: List[int]) -> int:
247306
# Calculate the new number of rows that we'd need to equalize the row

0 commit comments

Comments
 (0)