1+ import itertools
2+ import math
13from collections import defaultdict
24from dataclasses import dataclass
3- from typing import Dict , Iterator , List , Optional , Tuple
5+ from typing import Dict , Iterable , Iterator , List , Optional , Tuple
46
57import ray
68from .common import NodeIdStr
79from ray .data ._internal .memory_tracing import trace_deallocation
8- from ray .data .block import Block , BlockMetadata , Schema
10+ from ray .data .block import Block , BlockAccessor , BlockMetadata , Schema
911from ray .data .context import DataContext
1012from ray .types import ObjectRef
1113
1214
15+ @dataclass
16+ class BlockSlice :
17+ """A slice of a block."""
18+
19+ # Starting row offset (inclusive) within the block.
20+ start_offset : int
21+ # Ending row offset (exclusive) within the block.
22+ end_offset : int
23+
24+ @property
25+ def num_rows (self ) -> int :
26+ return self .end_offset - self .start_offset
27+
28+
1329@dataclass
1430class RefBundle :
1531 """A group of data block references and their metadata.
@@ -38,6 +54,12 @@ class RefBundle:
3854 # Whether we own the blocks (can safely destroy them).
3955 owns_blocks : bool
4056
57+ # The slices of the blocks in this bundle. After __post_init__, this is always
58+ # a list with length equal to len(blocks). Individual entries can be None to
59+ # represent a full block (equivalent to BlockSlice(0, num_rows)).
60+ # Pass None during construction to initialize all slices as None (full blocks).
61+ slices : Optional [List [Optional [BlockSlice ]]] = None
62+
4163 # This attribute is used by the split() operator to assign bundles to logical
4264 # output splits. It is otherwise None.
4365 output_split_idx : Optional [int ] = None
@@ -53,6 +75,27 @@ class RefBundle:
5375 def __post_init__ (self ):
5476 if not isinstance (self .blocks , tuple ):
5577 object .__setattr__ (self , "blocks" , tuple (self .blocks ))
78+
79+ if self .slices is None :
80+ self .slices = [None ] * len (self .blocks )
81+ else :
82+ assert len (self .blocks ) == len (
83+ self .slices
84+ ), "Number of blocks and slices must match"
85+ # Validate slice ranges
86+ for (_ , metadata ), block_slice in zip (self .blocks , self .slices ):
87+ if block_slice is not None :
88+ assert (
89+ block_slice .start_offset >= 0
90+ ), f"Slice start_offset must be non-negative: { block_slice .start_offset } "
91+ assert (
92+ block_slice .end_offset >= block_slice .start_offset
93+ ), f"Slice end_offset must be >= start_offset: [{ block_slice .start_offset } , { block_slice .end_offset } )"
94+ if metadata .num_rows is not None :
95+ assert (
96+ block_slice .end_offset <= metadata .num_rows
97+ ), f"Slice range [{ block_slice .start_offset } , { block_slice .end_offset } ) exceeds block num_rows: { metadata .num_rows } "
98+
5699 for b in self .blocks :
57100 assert isinstance (b , tuple ), b
58101 assert len (b ) == 2 , b
@@ -79,18 +122,52 @@ def metadata(self) -> List[BlockMetadata]:
79122 return [metadata for _ , metadata in self .blocks ]
80123
81124 def num_rows (self ) -> Optional [int ]:
82- """Number of rows present in this bundle, if known."""
125+ """Number of rows present in this bundle, if known.
126+
127+ Iterates through blocks and their corresponding slices to calculate the total.
128+ Note: Block metadata always refers to the full block, not the slice.
129+
130+ - If block_slice is None, uses the full block's metadata.num_rows
131+ - If block_slice is present, uses the slice's num_rows (partial block portion)
132+ - Returns None if any full block has unknown row count (metadata.num_rows is None)
133+ """
83134 total = 0
84- for m in self .metadata :
85- if m .num_rows is None :
86- return None
135+ for metadata , block_slice in zip (self .metadata , self .slices ):
136+ if block_slice is None :
137+ if metadata .num_rows is None :
138+ return None
139+ total += metadata .num_rows
87140 else :
88- total += m .num_rows
141+ total += block_slice .num_rows
89142 return total
90143
91144 def size_bytes (self ) -> int :
92- """Size of the blocks of this bundle in bytes."""
93- return sum (m .size_bytes for m in self .metadata )
145+ """Size of the blocks of this bundle in bytes.
146+
147+ Iterates through blocks and their corresponding slices to calculate the total size.
148+ Note: Block metadata always refers to the full block, not the slice.
149+
150+ - If block_slice is None, uses the full block's metadata.size_bytes
151+ - If block_slice is present but num_rows is unknown or zero, uses full metadata.size_bytes
152+ - If block_slice represents a partial block, estimates size proportionally based on
153+ (metadata.size_bytes / metadata.num_rows) * block_slice.num_rows
154+ - Otherwise, uses the full metadata.size_bytes
155+ """
156+ total = 0
157+ for (_ , metadata ), block_slice in zip (self .blocks , self .slices ):
158+ if block_slice is None :
159+ # Full block
160+ total += metadata .size_bytes
161+ elif metadata .num_rows is None or metadata .num_rows == 0 :
162+ # Unknown num_rows or empty block - use full metadata size
163+ total += metadata .size_bytes
164+ elif metadata .num_rows != block_slice .num_rows :
165+ # Partial block - estimate size based on rows
166+ per_row = metadata .size_bytes / metadata .num_rows
167+ total += max (1 , int (math .ceil (per_row * block_slice .num_rows )))
168+ else :
169+ total += metadata .size_bytes
170+ return total
94171
95172 def destroy_if_owned (self ) -> int :
96173 """Clears the object store memory for these blocks if owned.
@@ -143,6 +220,102 @@ def _get_cached_metadata(self) -> Dict[ObjectRef, "_ObjectMetadata"]:
143220
144221 return self ._cached_object_meta
145222
223+ def slice (self , needed_rows : int ) -> Tuple ["RefBundle" , "RefBundle" ]:
224+ """Slice a Ref Bundle into the first bundle containing the first `needed_rows` rows and the remaining bundle containing the remaining rows.
225+
226+ Args:
227+ needed_rows: Number of rows to take from the head of the bundle.
228+
229+ Returns:
230+ A tuple of (sliced_bundle, remaining_bundle). The needed rows must be less than the number of rows in the bundle.
231+ """
232+ assert needed_rows > 0 , "needed_rows must be positive."
233+ assert (
234+ self .num_rows () is not None
235+ ), "Cannot slice a RefBundle with unknown number of rows."
236+ assert (
237+ needed_rows < self .num_rows ()
238+ ), f"To slice a RefBundle, the number of requested rows must be less than the number of rows in the bundle. Requested { needed_rows } rows but bundle only has { self .num_rows ()} rows."
239+
240+ block_slices = []
241+ for metadata , block_slice in zip (self .metadata , self .slices ):
242+ if block_slice is None :
243+ # None represents a full block, convert to explicit BlockSlice
244+ assert (
245+ metadata .num_rows is not None
246+ ), "Cannot derive block slice for a RefBundle with unknown block row counts."
247+ block_slices .append (
248+ BlockSlice (start_offset = 0 , end_offset = metadata .num_rows )
249+ )
250+ else :
251+ block_slices .append (block_slice )
252+
253+ consumed_blocks : List [Tuple [ObjectRef [Block ], BlockMetadata ]] = []
254+ consumed_slices : List [BlockSlice ] = []
255+ remaining_blocks : List [Tuple [ObjectRef [Block ], BlockMetadata ]] = []
256+ remaining_slices : List [BlockSlice ] = []
257+
258+ rows_to_take = needed_rows
259+
260+ for (block_ref , metadata ), block_slice in zip (self .blocks , block_slices ):
261+ block_rows = block_slice .num_rows
262+ if rows_to_take >= block_rows :
263+ consumed_blocks .append ((block_ref , metadata ))
264+ consumed_slices .append (block_slice )
265+ rows_to_take -= block_rows
266+ else :
267+ if rows_to_take == 0 :
268+ remaining_blocks .append ((block_ref , metadata ))
269+ remaining_slices .append (block_slice )
270+ continue
271+ consume_slice = BlockSlice (
272+ start_offset = block_slice .start_offset ,
273+ end_offset = block_slice .start_offset + rows_to_take ,
274+ )
275+ consumed_blocks .append ((block_ref , metadata ))
276+ consumed_slices .append (consume_slice )
277+
278+ leftover_rows = block_rows - rows_to_take
279+ if leftover_rows > 0 :
280+ remainder_slice = BlockSlice (
281+ start_offset = consume_slice .end_offset ,
282+ end_offset = block_slice .end_offset ,
283+ )
284+ remaining_blocks .append ((block_ref , metadata ))
285+ remaining_slices .append (remainder_slice )
286+
287+ rows_to_take = 0
288+
289+ sliced_bundle = RefBundle (
290+ blocks = tuple (consumed_blocks ),
291+ schema = self .schema ,
292+ owns_blocks = False ,
293+ slices = consumed_slices if consumed_slices else None ,
294+ )
295+
296+ remaining_bundle = RefBundle (
297+ blocks = tuple (remaining_blocks ),
298+ schema = self .schema ,
299+ owns_blocks = False ,
300+ slices = remaining_slices if remaining_slices else None ,
301+ )
302+
303+ return sliced_bundle , remaining_bundle
304+
305+ @classmethod
306+ def merge_ref_bundles (cls , bundles : List ["RefBundle" ]) -> "RefBundle" :
307+ assert bundles , "Cannot merge an empty list of RefBundles."
308+ merged_blocks = list (itertools .chain (* [bundle .blocks for bundle in bundles ]))
309+ merged_slices = list (itertools .chain (* [bundle .slices for bundle in bundles ]))
310+ return cls (
311+ blocks = tuple (merged_blocks ),
312+ schema = bundles [0 ].schema , # Assume all bundles have the same schema
313+ owns_blocks = bundles [
314+ 0
315+ ].owns_blocks , # Assume all bundles have the same ownership
316+ slices = merged_slices ,
317+ )
318+
146319 def __eq__ (self , other ) -> bool :
147320 return self is other
148321
@@ -152,6 +325,38 @@ def __hash__(self) -> int:
152325 def __len__ (self ) -> int :
153326 return len (self .blocks )
154327
328+ def __str__ (self ) -> str :
329+ lines = [
330+ f"RefBundle({ len (self .blocks )} blocks," ,
331+ f" { self .num_rows ()} rows," ,
332+ f" schema={ self .schema } ," ,
333+ f" owns_blocks={ self .owns_blocks } ," ,
334+ " blocks=(" ,
335+ ]
336+
337+ # Loop through each block and show details
338+ for i , ((block_ref , metadata ), block_slice ) in enumerate (
339+ zip (self .blocks , self .slices )
340+ ):
341+ row_str = (
342+ f"{ metadata .num_rows } rows"
343+ if metadata .num_rows is not None
344+ else "unknown rows"
345+ )
346+ bytes_str = f"{ metadata .size_bytes } bytes"
347+ slice_str = (
348+ f"slice={ block_slice } "
349+ if block_slice is not None
350+ else "slice=None (full block)"
351+ )
352+
353+ lines .append (f" { i } : { row_str } , { bytes_str } , { slice_str } " )
354+
355+ lines .append (" )" )
356+ lines .append (")" )
357+
358+ return "\n " .join (lines )
359+
155360
156361@dataclass
157362class _ObjectMetadata :
@@ -170,3 +375,25 @@ def _ref_bundles_iterator_to_block_refs_list(
170375 return [
171376 block_ref for ref_bundle in ref_bundles for block_ref in ref_bundle .block_refs
172377 ]
378+
379+
380+ def _iter_sliced_blocks (
381+ blocks : Iterable [Block ],
382+ slices : List [Optional [BlockSlice ]],
383+ ) -> Iterator [Block ]:
384+ blocks_list = list (blocks )
385+ for block , block_slice in zip (blocks_list , slices ):
386+ if block_slice is None :
387+ # None represents a full block - yield it as is
388+ yield block
389+ else :
390+ accessor = BlockAccessor .for_block (block )
391+ start = block_slice .start_offset
392+ end = block_slice .end_offset
393+ assert start <= end , "start must be less than end"
394+ assert start >= 0 , "start must be non-negative"
395+ assert (
396+ end <= accessor .num_rows ()
397+ ), "end must be less than or equal to the number of rows in the block"
398+
399+ yield accessor .slice (start , end , copy = False )
0 commit comments