Skip to content

Commit 49c4562

Browse files
[Data] Streaming repartition enforces row_num per block (ray-project#57984)
## Description

Currently, streaming repartition applies a map transform to each block independently and does not merge leftover rows across blocks, so it cannot guarantee exact row counts per output block. This PR introduces a new design that computes, on the driver, the input block ranges for every output block. It avoids driver-side block fetching while ensuring correctness and leveraging the efficiency of parallel map tasks.

## Related issues

Closes ray-project#57165

## Additional information

---------

Signed-off-by: You-Cheng Lin (Owen) <mses010108@gmail.com>
Signed-off-by: You-Cheng Lin <mses010108@gmail.com>
1 parent d15748d commit 49c4562

File tree

12 files changed

+989
-31
lines changed

12 files changed

+989
-31
lines changed

python/ray/data/_internal/execution/interfaces/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .execution_options import ExecutionOptions, ExecutionResources
33
from .executor import Executor, OutputIterator
44
from .physical_operator import PhysicalOperator, ReportsExtraResourceUsage
5-
from .ref_bundle import RefBundle
5+
from .ref_bundle import BlockSlice, RefBundle
66
from .task_context import TaskContext
77
from .transform_fn import AllToAllTransformFn
88

@@ -15,6 +15,7 @@
1515
"OutputIterator",
1616
"PhysicalOperator",
1717
"RefBundle",
18+
"BlockSlice",
1819
"ReportsExtraResourceUsage",
1920
"TaskContext",
2021
]

python/ray/data/_internal/execution/interfaces/ref_bundle.py

Lines changed: 236 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,31 @@
1+
import itertools
2+
import math
13
from collections import defaultdict
24
from dataclasses import dataclass
3-
from typing import Dict, Iterator, List, Optional, Tuple
5+
from typing import Dict, Iterable, Iterator, List, Optional, Tuple
46

57
import ray
68
from .common import NodeIdStr
79
from ray.data._internal.memory_tracing import trace_deallocation
8-
from ray.data.block import Block, BlockMetadata, Schema
10+
from ray.data.block import Block, BlockAccessor, BlockMetadata, Schema
911
from ray.data.context import DataContext
1012
from ray.types import ObjectRef
1113

1214

15+
@dataclass
16+
class BlockSlice:
17+
"""A slice of a block."""
18+
19+
# Starting row offset (inclusive) within the block.
20+
start_offset: int
21+
# Ending row offset (exclusive) within the block.
22+
end_offset: int
23+
24+
@property
25+
def num_rows(self) -> int:
26+
return self.end_offset - self.start_offset
27+
28+
1329
@dataclass
1430
class RefBundle:
1531
"""A group of data block references and their metadata.
@@ -38,6 +54,12 @@ class RefBundle:
3854
# Whether we own the blocks (can safely destroy them).
3955
owns_blocks: bool
4056

57+
# The slices of the blocks in this bundle. After __post_init__, this is always
58+
# a list with length equal to len(blocks). Individual entries can be None to
59+
# represent a full block (equivalent to BlockSlice(0, num_rows)).
60+
# Pass None during construction to initialize all slices as None (full blocks).
61+
slices: Optional[List[Optional[BlockSlice]]] = None
62+
4163
# This attribute is used by the split() operator to assign bundles to logical
4264
# output splits. It is otherwise None.
4365
output_split_idx: Optional[int] = None
@@ -53,6 +75,27 @@ class RefBundle:
5375
def __post_init__(self):
5476
if not isinstance(self.blocks, tuple):
5577
object.__setattr__(self, "blocks", tuple(self.blocks))
78+
79+
if self.slices is None:
80+
self.slices = [None] * len(self.blocks)
81+
else:
82+
assert len(self.blocks) == len(
83+
self.slices
84+
), "Number of blocks and slices must match"
85+
# Validate slice ranges
86+
for (_, metadata), block_slice in zip(self.blocks, self.slices):
87+
if block_slice is not None:
88+
assert (
89+
block_slice.start_offset >= 0
90+
), f"Slice start_offset must be non-negative: {block_slice.start_offset}"
91+
assert (
92+
block_slice.end_offset >= block_slice.start_offset
93+
), f"Slice end_offset must be >= start_offset: [{block_slice.start_offset}, {block_slice.end_offset})"
94+
if metadata.num_rows is not None:
95+
assert (
96+
block_slice.end_offset <= metadata.num_rows
97+
), f"Slice range [{block_slice.start_offset}, {block_slice.end_offset}) exceeds block num_rows: {metadata.num_rows}"
98+
5699
for b in self.blocks:
57100
assert isinstance(b, tuple), b
58101
assert len(b) == 2, b
@@ -79,18 +122,52 @@ def metadata(self) -> List[BlockMetadata]:
79122
return [metadata for _, metadata in self.blocks]
80123

81124
def num_rows(self) -> Optional[int]:
82-
"""Number of rows present in this bundle, if known."""
125+
"""Number of rows present in this bundle, if known.
126+
127+
Iterates through blocks and their corresponding slices to calculate the total.
128+
Note: Block metadata always refers to the full block, not the slice.
129+
130+
- If block_slice is None, uses the full block's metadata.num_rows
131+
- If block_slice is present, uses the slice's num_rows (partial block portion)
132+
- Returns None if any full block has unknown row count (metadata.num_rows is None)
133+
"""
83134
total = 0
84-
for m in self.metadata:
85-
if m.num_rows is None:
86-
return None
135+
for metadata, block_slice in zip(self.metadata, self.slices):
136+
if block_slice is None:
137+
if metadata.num_rows is None:
138+
return None
139+
total += metadata.num_rows
87140
else:
88-
total += m.num_rows
141+
total += block_slice.num_rows
89142
return total
90143

91144
def size_bytes(self) -> int:
92-
"""Size of the blocks of this bundle in bytes."""
93-
return sum(m.size_bytes for m in self.metadata)
145+
"""Size of the blocks of this bundle in bytes.
146+
147+
Iterates through blocks and their corresponding slices to calculate the total size.
148+
Note: Block metadata always refers to the full block, not the slice.
149+
150+
- If block_slice is None, uses the full block's metadata.size_bytes
151+
- If block_slice is present but num_rows is unknown or zero, uses full metadata.size_bytes
152+
- If block_slice represents a partial block, estimates size proportionally based on
153+
(metadata.size_bytes / metadata.num_rows) * block_slice.num_rows
154+
- Otherwise, uses the full metadata.size_bytes
155+
"""
156+
total = 0
157+
for (_, metadata), block_slice in zip(self.blocks, self.slices):
158+
if block_slice is None:
159+
# Full block
160+
total += metadata.size_bytes
161+
elif metadata.num_rows is None or metadata.num_rows == 0:
162+
# Unknown num_rows or empty block - use full metadata size
163+
total += metadata.size_bytes
164+
elif metadata.num_rows != block_slice.num_rows:
165+
# Partial block - estimate size based on rows
166+
per_row = metadata.size_bytes / metadata.num_rows
167+
total += max(1, int(math.ceil(per_row * block_slice.num_rows)))
168+
else:
169+
total += metadata.size_bytes
170+
return total
94171

95172
def destroy_if_owned(self) -> int:
96173
"""Clears the object store memory for these blocks if owned.
@@ -143,6 +220,102 @@ def _get_cached_metadata(self) -> Dict[ObjectRef, "_ObjectMetadata"]:
143220

144221
return self._cached_object_meta
145222

223+
def slice(self, needed_rows: int) -> Tuple["RefBundle", "RefBundle"]:
224+
"""Slice a Ref Bundle into the first bundle containing the first `needed_rows` rows and the remaining bundle containing the remaining rows.
225+
226+
Args:
227+
needed_rows: Number of rows to take from the head of the bundle.
228+
229+
Returns:
230+
A tuple of (sliced_bundle, remaining_bundle). The needed rows must be less than the number of rows in the bundle.
231+
"""
232+
assert needed_rows > 0, "needed_rows must be positive."
233+
assert (
234+
self.num_rows() is not None
235+
), "Cannot slice a RefBundle with unknown number of rows."
236+
assert (
237+
needed_rows < self.num_rows()
238+
), f"To slice a RefBundle, the number of requested rows must be less than the number of rows in the bundle. Requested {needed_rows} rows but bundle only has {self.num_rows()} rows."
239+
240+
block_slices = []
241+
for metadata, block_slice in zip(self.metadata, self.slices):
242+
if block_slice is None:
243+
# None represents a full block, convert to explicit BlockSlice
244+
assert (
245+
metadata.num_rows is not None
246+
), "Cannot derive block slice for a RefBundle with unknown block row counts."
247+
block_slices.append(
248+
BlockSlice(start_offset=0, end_offset=metadata.num_rows)
249+
)
250+
else:
251+
block_slices.append(block_slice)
252+
253+
consumed_blocks: List[Tuple[ObjectRef[Block], BlockMetadata]] = []
254+
consumed_slices: List[BlockSlice] = []
255+
remaining_blocks: List[Tuple[ObjectRef[Block], BlockMetadata]] = []
256+
remaining_slices: List[BlockSlice] = []
257+
258+
rows_to_take = needed_rows
259+
260+
for (block_ref, metadata), block_slice in zip(self.blocks, block_slices):
261+
block_rows = block_slice.num_rows
262+
if rows_to_take >= block_rows:
263+
consumed_blocks.append((block_ref, metadata))
264+
consumed_slices.append(block_slice)
265+
rows_to_take -= block_rows
266+
else:
267+
if rows_to_take == 0:
268+
remaining_blocks.append((block_ref, metadata))
269+
remaining_slices.append(block_slice)
270+
continue
271+
consume_slice = BlockSlice(
272+
start_offset=block_slice.start_offset,
273+
end_offset=block_slice.start_offset + rows_to_take,
274+
)
275+
consumed_blocks.append((block_ref, metadata))
276+
consumed_slices.append(consume_slice)
277+
278+
leftover_rows = block_rows - rows_to_take
279+
if leftover_rows > 0:
280+
remainder_slice = BlockSlice(
281+
start_offset=consume_slice.end_offset,
282+
end_offset=block_slice.end_offset,
283+
)
284+
remaining_blocks.append((block_ref, metadata))
285+
remaining_slices.append(remainder_slice)
286+
287+
rows_to_take = 0
288+
289+
sliced_bundle = RefBundle(
290+
blocks=tuple(consumed_blocks),
291+
schema=self.schema,
292+
owns_blocks=False,
293+
slices=consumed_slices if consumed_slices else None,
294+
)
295+
296+
remaining_bundle = RefBundle(
297+
blocks=tuple(remaining_blocks),
298+
schema=self.schema,
299+
owns_blocks=False,
300+
slices=remaining_slices if remaining_slices else None,
301+
)
302+
303+
return sliced_bundle, remaining_bundle
304+
305+
@classmethod
306+
def merge_ref_bundles(cls, bundles: List["RefBundle"]) -> "RefBundle":
307+
assert bundles, "Cannot merge an empty list of RefBundles."
308+
merged_blocks = list(itertools.chain(*[bundle.blocks for bundle in bundles]))
309+
merged_slices = list(itertools.chain(*[bundle.slices for bundle in bundles]))
310+
return cls(
311+
blocks=tuple(merged_blocks),
312+
schema=bundles[0].schema, # Assume all bundles have the same schema
313+
owns_blocks=bundles[
314+
0
315+
].owns_blocks, # Assume all bundles have the same ownership
316+
slices=merged_slices,
317+
)
318+
146319
def __eq__(self, other) -> bool:
147320
return self is other
148321

@@ -152,6 +325,38 @@ def __hash__(self) -> int:
152325
def __len__(self) -> int:
153326
return len(self.blocks)
154327

328+
def __str__(self) -> str:
329+
lines = [
330+
f"RefBundle({len(self.blocks)} blocks,",
331+
f" {self.num_rows()} rows,",
332+
f" schema={self.schema},",
333+
f" owns_blocks={self.owns_blocks},",
334+
" blocks=(",
335+
]
336+
337+
# Loop through each block and show details
338+
for i, ((block_ref, metadata), block_slice) in enumerate(
339+
zip(self.blocks, self.slices)
340+
):
341+
row_str = (
342+
f"{metadata.num_rows} rows"
343+
if metadata.num_rows is not None
344+
else "unknown rows"
345+
)
346+
bytes_str = f"{metadata.size_bytes} bytes"
347+
slice_str = (
348+
f"slice={block_slice}"
349+
if block_slice is not None
350+
else "slice=None (full block)"
351+
)
352+
353+
lines.append(f" {i}: {row_str}, {bytes_str}, {slice_str}")
354+
355+
lines.append(" )")
356+
lines.append(")")
357+
358+
return "\n".join(lines)
359+
155360

156361
@dataclass
157362
class _ObjectMetadata:
@@ -170,3 +375,25 @@ def _ref_bundles_iterator_to_block_refs_list(
170375
return [
171376
block_ref for ref_bundle in ref_bundles for block_ref in ref_bundle.block_refs
172377
]
378+
379+
380+
def _iter_sliced_blocks(
381+
blocks: Iterable[Block],
382+
slices: List[Optional[BlockSlice]],
383+
) -> Iterator[Block]:
384+
blocks_list = list(blocks)
385+
for block, block_slice in zip(blocks_list, slices):
386+
if block_slice is None:
387+
# None represents a full block - yield it as is
388+
yield block
389+
else:
390+
accessor = BlockAccessor.for_block(block)
391+
start = block_slice.start_offset
392+
end = block_slice.end_offset
393+
assert start <= end, "start must be less than end"
394+
assert start >= 0, "start must be non-negative"
395+
assert (
396+
end <= accessor.num_rows()
397+
), "end must be less than or equal to the number of rows in the block"
398+
399+
yield accessor.slice(start, end, copy=False)

python/ray/data/_internal/execution/operators/actor_pool_map_operator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ray.data._internal.execution.bundle_queue import create_bundle_queue
2121
from ray.data._internal.execution.bundle_queue.bundle_queue import BundleQueue
2222
from ray.data._internal.execution.interfaces import (
23+
BlockSlice,
2324
ExecutionOptions,
2425
ExecutionResources,
2526
NodeIdStr,
@@ -32,7 +33,11 @@
3233
ActorLocationTracker,
3334
get_or_create_actor_location_tracker,
3435
)
35-
from ray.data._internal.execution.operators.map_operator import MapOperator, _map_task
36+
from ray.data._internal.execution.operators.map_operator import (
37+
BaseRefBundler,
38+
MapOperator,
39+
_map_task,
40+
)
3641
from ray.data._internal.execution.operators.map_transformer import MapTransformer
3742
from ray.data._internal.execution.util import locality_string
3843
from ray.data._internal.remote_fn import _add_system_error_to_retry_exceptions
@@ -71,6 +76,7 @@ def __init__(
7176
compute_strategy: ActorPoolStrategy,
7277
name: str = "ActorPoolMap",
7378
min_rows_per_bundle: Optional[int] = None,
79+
ref_bundler: Optional[BaseRefBundler] = None,
7480
supports_fusion: bool = True,
7581
map_task_kwargs: Optional[Dict[str, Any]] = None,
7682
ray_remote_args_fn: Optional[Callable[[], Dict[str, Any]]] = None,
@@ -91,6 +97,7 @@ def __init__(
9197
transform_fn, or None to use the block size. Setting the batch size is
9298
important for the performance of GPU-accelerated transform functions.
9399
The actual rows passed may be less if the dataset is small.
100+
ref_bundler: The ref bundler to use for this operator.
94101
supports_fusion: Whether this operator supports fusion with other operators.
95102
map_task_kwargs: A dictionary of kwargs to pass to the map task. You can
96103
access these kwargs through the `TaskContext.kwargs` dictionary.
@@ -113,6 +120,7 @@ def __init__(
113120
name,
114121
target_max_block_size_override,
115122
min_rows_per_bundle,
123+
ref_bundler,
116124
supports_fusion,
117125
map_task_kwargs,
118126
ray_remote_args_fn,
@@ -330,6 +338,7 @@ def _dispatch_tasks(self):
330338
self.data_context,
331339
ctx,
332340
*input_blocks,
341+
slices=bundle.slices,
333342
**self.get_map_task_kwargs(),
334343
)
335344

@@ -571,13 +580,15 @@ def submit(
571580
data_context: DataContext,
572581
ctx: TaskContext,
573582
*blocks: Block,
583+
slices: Optional[List[BlockSlice]] = None,
574584
**kwargs: Dict[str, Any],
575585
) -> Iterator[Union[Block, List[BlockMetadata]]]:
576586
yield from _map_task(
577587
self._map_transformer,
578588
data_context,
579589
ctx,
580590
*blocks,
591+
slices=slices,
581592
**kwargs,
582593
)
583594

0 commit comments

Comments
 (0)