33from collections import deque
44from typing import Any , Collection , Dict , List , Optional , Tuple
55
6+ from ray ._private .ray_constants import env_float
67from ray .data ._internal .execution .bundle_queue import (
78 HashLinkedQueue ,
89)
2223from ray .data .context import DataContext
2324from ray .types import ObjectRef
2425
# Cap on internal buffering when locality hints are enabled: the splitter may
# hold at most this factor times the number of output splits (N) before it is
# forced to dispatch, trading locality optimization for bounded buffering and
# liveness. Overridable via the environment variable of the same name.
DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR = env_float(
    "RAY_DATA_DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR", 2
)
29+
2530
2631class OutputSplitter (InternalQueueOperatorMixin , PhysicalOperator ):
2732 """An operator that splits the given data into `n` output splits.
@@ -68,14 +73,21 @@ def __init__(
6873 f"len({ locality_hints } ) != { n } "
6974 )
7075 self ._locality_hints = locality_hints
76+
77+ # To optimize locality, we might defer dispatching of the bundles to allow
78+ # for better node affinity by allowing next receiver to wait for a block
79+ # with preferred locality (minimizing data movement).
80+ #
81+ # However, to guarantee liveness we cap buffering to not exceed
82+ #
83+ # DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR * N
84+ #
85+ # Where N is the number of outputs the sequence is being split into
7186 if locality_hints :
72- # To optimize locality, we should buffer a certain number of elements
73- # internally before dispatch to allow the locality algorithm a good chance
74- # of selecting a preferred location. We use a small multiple of `n` since
75- # it's reasonable to buffer a couple blocks per consumer.
76- self ._min_buffer_size = 2 * n
87+ self ._max_buffer_size = DEFAULT_OUTPUT_SPLITTER_MAX_BUFFERING_FACTOR * n
7788 else :
78- self ._min_buffer_size = 0
89+ self ._max_buffer_size = 0
90+
7991 self ._locality_hits = 0
8092 self ._locality_misses = 0
8193
@@ -92,7 +104,7 @@ def start(self, options: ExecutionOptions) -> None:
92104 if options .preserve_order :
93105 # If preserve_order is set, we need to ignore locality hints to ensure determinism.
94106 self ._locality_hints = None
95- self ._min_buffer_size = 0
107+ self ._max_buffer_size = 0
96108
97109 super ().start (options )
98110
@@ -128,13 +140,24 @@ def _add_input_inner(self, bundle, input_index) -> None:
128140 raise ValueError ("OutputSplitter requires bundles with known row count" )
129141 self ._buffer .add (bundle )
130142 self ._metrics .on_input_queued (bundle )
131- self ._dispatch_bundles ()
143+ # Try dispatch buffered bundles
144+ self ._try_dispatch_bundles ()
132145
133146 def all_inputs_done (self ) -> None :
134147 super ().all_inputs_done ()
135- if not self ._equal :
136- self ._dispatch_bundles (dispatch_all = True )
137- assert not self ._buffer , "Should have dispatched all bundles."
148+
149+ # First, attempt to dispatch bundles based on the locality preferences
150+ # (if configured)
151+ if self ._locality_hints :
152+ # NOTE: If equal distribution is not requested, we will force
153+ # the dispatching
154+ self ._try_dispatch_bundles (force = not self ._equal )
155+
156+ if not self ._equal :
157+ assert not self ._buffer , "All bundles should have been dispatched"
158+ return
159+
160+ if not self ._buffer :
138161 return
139162
140163 # Otherwise:
@@ -191,57 +214,93 @@ def progress_str(self) -> str:
191214 else :
192215 return "[locality disabled]"
193216
    def _try_dispatch_bundles(self, force: bool = False) -> None:
        """Attempt to dispatch buffered bundles to their output splits.

        Args:
            force: If True, dispatch bundles even when the buffer is below the
                max-buffering threshold (used to flush remaining bundles, e.g.
                once all inputs are done).
        """
        start_time = time.perf_counter()

        # Currently, there are 2 modes of operation when dispatching
        # accumulated bundles:
        #
        # 1. Best-effort: we do a single pass over the whole buffer
        #    and try to dispatch all bundles either
        #
        #    a) Based on their locality (if feasible)
        #    b) Longest-waiting if buffer exceeds max-size threshold
        #
        # 2. Mandatory: when whole buffer has to be dispatched (for ex,
        #    upon completion of the dataset execution)
        #
        for _ in range(len(self._buffer)):
            # Get target output index of the next receiver (the split that
            # has received the fewest rows so far)
            target_output_index = self._select_next_output_index()
            # Look up a bundle already located on the receiver's preferred node
            preferred_bundle = self._find_preferred_bundle(target_output_index)

            if preferred_bundle:
                target_bundle = preferred_bundle
            elif len(self._buffer) >= self._max_buffer_size or force:
                # If we're not able to find a preferred bundle and buffer size is above
                # the cap, we pop the longest awaiting and pass to the next receiver
                target_bundle = self._buffer.peek_next()
                assert target_bundle is not None
            else:
                # Provided that we weren't able to either locate preferred bundle
                # or dequeue the head one, we bail out from iteration
                break

            # In case, when we can't safely dispatch (to avoid violating distribution
            # requirements, e.g. equal-split mode), short-circuit
            if not self._can_safely_dispatch(
                target_output_index, target_bundle.num_rows()
            ):
                break

            # Pop the selected bundle from the buffer and account for it in metrics
            self._buffer.remove(target_bundle)
            self._metrics.on_input_dequeued(target_bundle)

            target_bundle.output_split_idx = target_output_index

            self._num_output[target_output_index] += target_bundle.num_rows()
            self._output_queue.append(target_bundle)
            self._metrics.on_output_queued(target_bundle)

            # Track locality hit/miss stats: a hit means the dispatched bundle
            # matched the receiver's preferred location
            if self._locality_hints:
                if preferred_bundle:
                    self._locality_hits += 1
                else:
                    self._locality_misses += 1

        self._output_splitter_overhead_time += time.perf_counter() - start_time
220274
221- def _select_output_index (self ) -> int :
275+ def _select_next_output_index (self ) -> int :
222276 # Greedily dispatch to the consumer with the least data so far.
223277 i , _ = min (enumerate (self ._num_output ), key = lambda t : t [1 ])
224278 return i
225279
226- def _peek_bundle_to_dispatch (self , target_index : int ) -> RefBundle :
280+ def _find_preferred_bundle (self , target_output_index : int ) -> Optional [ RefBundle ] :
227281 if self ._locality_hints :
228- preferred_loc = self ._locality_hints [target_index ]
282+ preferred_loc = self ._locality_hints [target_output_index ]
283+
284+ # TODO make this more efficient (adding inverse hash-map)
229285 for bundle in self ._buffer :
230286 if preferred_loc in self ._get_locations (bundle ):
231287 return bundle
232288
233- return self . _buffer . peek_next ()
289+ return None
234290
235- def _can_safely_dispatch (self , target_index : int , nrow : int ) -> bool :
291+ def _can_safely_dispatch (self , target_index : int , target_num_rows : int ) -> bool :
236292 if not self ._equal :
237293 # If not in equals mode, dispatch away with no buffer requirements.
238294 return True
295+
296+ # Simulate dispatching a bundle to the target receiver
239297 output_distribution = self ._num_output .copy ()
240- output_distribution [target_index ] += nrow
298+ output_distribution [target_index ] += target_num_rows
241299 buffer_requirement = self ._calculate_buffer_requirement (output_distribution )
242- buffer_size = self ._buffer .num_rows ()
300+ # Subtract target bundle size from the projected buffer
301+ buffer_size = self ._buffer .num_rows () - target_num_rows
243302 # Check if we have enough rows LEFT after dispatching to equalize.
244- return buffer_size - nrow >= buffer_requirement
303+ return buffer_size >= buffer_requirement
245304
246305 def _calculate_buffer_requirement (self , output_distribution : List [int ]) -> int :
247306 # Calculate the new number of rows that we'd need to equalize the row
0 commit comments