Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
self._streaming_gen = streaming_gen
self._output_ready_callback = output_ready_callback
self._task_done_callback = task_done_callback
self._pending_block_pair = None

def get_waitable(self) -> ObjectRefGenerator:
    """Return the streaming generator stored on this task.

    Returns:
        The object held in ``self._streaming_gen``.
    """
    streaming_gen = self._streaming_gen
    return streaming_gen
Expand All @@ -128,8 +129,13 @@ def on_data_ready(self, max_bytes_to_read: Optional[int]) -> int:
"""
bytes_read = 0
while max_bytes_to_read is None or bytes_read < max_bytes_to_read:
block_ref, meta_ref = self._pending_block_pair or (
ray.ObjectRef.nil(),
ray.ObjectRef.nil(),
)
try:
block_ref = self._streaming_gen._next_sync(0)
if block_ref.is_nil():
block_ref = self._streaming_gen._next_sync(0)
if block_ref.is_nil():
# The generator currently doesn't have new output.
# And it's not stopped yet.
Expand All @@ -139,9 +145,18 @@ def on_data_ready(self, max_bytes_to_read: Optional[int]) -> int:
break

try:
if meta_ref.is_nil():
meta_ref = self._streaming_gen._next_sync(1)
if meta_ref.is_nil():
self._pending_block_pair = (block_ref, meta_ref)
break
meta_with_schema: "BlockMetadataWithSchema" = ray.get(
next(self._streaming_gen)
meta_ref, timeout=1
)
except ray.exceptions.GetTimeoutError:
logger.warning(f"Get Meta timeout for (block_ref={block_ref.hex()})")
self._pending_block_pair = (block_ref, meta_ref)
break
except StopIteration:
# The generator should always yield 2 values (block and metadata)
# each time. If we get a StopIteration here, it means an error
Expand All @@ -164,10 +179,15 @@ def on_data_ready(self, max_bytes_to_read: Optional[int]) -> int:
schema=meta_with_schema.schema,
),
)
self._pending_block_pair = None
bytes_read += meta.size_bytes

return bytes_read

@property
def pending_block_pair(self):
    """Expose the (block_ref, meta_ref) pair currently stored as pending.

    Returns:
        The value of ``self._pending_block_pair``; ``None`` when no pair is
        being held.
    """
    pending = self._pending_block_pair
    return pending


class MetadataOpTask(OpTask):
"""Represents an OpTask that only handles metadata, instead of Block data."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def process_completed_tasks(
# Process completed Ray tasks and notify operators.
num_errored_blocks = 0
if active_tasks:
ready, _ = ray.wait(
ready, remaining = ray.wait(
list(active_tasks.keys()),
num_returns=len(active_tasks),
fetch_local=False,
Expand All @@ -513,6 +513,12 @@ def process_completed_tasks(
for ref in ready:
state, task = active_tasks[ref]
ready_tasks_by_op[state].append(task)
for ref in remaining:
state, task = active_tasks[ref]
# If the task still holds a pending (block, meta) pair, schedule it for processing even though its streaming_gen is not ready.
# A pair is left pending when an earlier attempt to fetch its metadata hit an exception or timeout, or the metadata ref was not yet available, leaving data unconsumed.
if isinstance(task, DataOpTask) and task.pending_block_pair:
ready_tasks_by_op[state].append(task)

for state, ready_tasks in ready_tasks_by_op.items():
ready_tasks = sorted(ready_tasks, key=lambda t: t.task_index())
Expand Down