Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/ray/data/_internal/execution/bundle_queue/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def num_blocks(self) -> int:
"""Return the total # of blocks across all bundles."""
return self._num_blocks

def num_bundles(self) -> int:
    """Return the total # of bundles currently tracked by this queue."""
    return self._num_bundles

def num_rows(self) -> int:
    """Return the total # of rows across all bundles in the queue."""
    return self._num_rows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ def has_execution_finished(self) -> bool:
# - All input blocks have been ingested
# - Internal queue is empty
# - There are no active or pending tasks

return self._is_execution_marked_finished or (
self._inputs_complete
and self.num_active_tasks() == 0
Expand All @@ -450,8 +449,9 @@ def has_completed(self) -> bool:
# Draining the internal output queue is important to free object refs.
return (
self.has_execution_finished()
and not self.has_next()
and internal_output_queue_num_blocks == 0
# TODO following check is redundant; remove
and not self.has_next()
)

def get_stats(self) -> StatsDict:
Expand Down Expand Up @@ -553,7 +553,7 @@ def start(self, options: ExecutionOptions) -> None:
"""
self._started = True

def should_add_input(self) -> bool:
def can_add_input(self) -> bool:
"""Return whether it is desirable to add input to this operator right now.

Operators can customize the implementation of this method to apply additional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ def __init__(
self._map_worker_cls = type(f"MapWorker({self.name})", (_MapWorker,), {})
# Cached actor class.
self._actor_cls = None
# Whether no more submittable bundles will be added.
self._inputs_done = False
self._actor_locality_enabled: Optional[bool] = None

# Locality metrics
Expand Down Expand Up @@ -280,8 +278,18 @@ def start(self, options: ExecutionOptions):
"enough resources for the requested actor pool."
)

def should_add_input(self) -> bool:
return self._actor_pool.num_free_task_slots() > 0
def can_add_input(self) -> bool:
"""NOTE: PLEASE READ CAREFULLY

This method has to abide by the following contract to guarantee Operator's
ability to handle all provided inputs (liveness):

- This method should only return `True` when operator is guaranteed
to be able to launch a task, meaning that subsequent `op.add_input(...)`
should be able to launch a task.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be handled later. The contract is somewhat fragile, because there are no constraints on WHEN the next add_input will be called.
We should

  1. either make can_add_input and add_input atomic
  2. or introduce some boundaries at which the op's states can change

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the contract is that:

  • Before add_input, can_add_input must be called (which is done when we're selecting operator to dispatch to)
  • This is enforced through assertions inside add_input calling can_add_input


"""
return self._actor_task_selector.can_schedule_task()

def _start_actor(
self, labels: Dict[str, str], logical_actor_id: str
Expand Down Expand Up @@ -317,33 +325,44 @@ def _task_done_callback(res_ref):
if not has_actor:
# Actor has already been killed.
return
# Try to dispatch queued tasks.
self._dispatch_tasks()

self._submit_metadata_task(
res_ref,
lambda: _task_done_callback(res_ref),
)
return actor, res_ref

def _add_bundled_input(self, bundle: RefBundle):
def _try_schedule_task(self, bundle: RefBundle, strict: bool):
# Notify first input for deferred initialization (e.g., Iceberg schema evolution).
self._notify_first_input(bundle)
# Enqueue input bundle
self._bundle_queue.add(bundle)
self._metrics.on_input_queued(bundle)
# Try to dispatch all bundles in the queue, including this new bundle.
self._dispatch_tasks()

def _dispatch_tasks(self):
"""Try to dispatch tasks from the bundle buffer to the actor pool.
if strict:
# NOTE: In case of strict input handling protocol at least 1 task
# must be launched. Therefore, we assert here that it was
# verified that the task could be scheduled before invoking
# this method
assert self.can_add_input(), f"Operator {self} can not handle input!"

# Try to dispatch new tasks
submitted = self._try_schedule_tasks_internal(strict=strict)

if strict:
assert (
submitted >= 1
), f"Expected at least 1 task launched (launched {submitted})"

def _try_schedule_tasks_internal(self, strict: bool) -> int:
"""Try to dispatch tasks from the internal queue"""

num_submitted_tasks = 0

This is called when:
* a new input bundle is added,
* a task finishes,
* a new worker has been created.
"""
for bundle, actor in self._actor_task_selector.select_actors(
self._bundle_queue, self._actor_locality_enabled
self._bundle_queue,
self._actor_locality_enabled,
strict=strict,
):
# Submit the map task.
self._metrics.on_input_dequeued(bundle)
Expand All @@ -369,15 +388,15 @@ def _dispatch_tasks(self):
def _task_done_callback(actor_to_return):
# Return the actor that was running the task to the pool.
self._actor_pool.on_task_completed(actor_to_return)
# Dipsatch more tasks.
self._dispatch_tasks()

from functools import partial

self._submit_data_task(
gen, bundle, partial(_task_done_callback, actor_to_return=actor)
)

num_submitted_tasks += 1

# Update locality metrics
if (
self._actor_pool.running_actors()[actor].actor_location
Expand All @@ -387,6 +406,8 @@ def _task_done_callback(actor_to_return):
else:
self._locality_misses += 1

return num_submitted_tasks

def _refresh_actor_cls(self):
"""When `self._ray_remote_args_fn` is specified, this method should
be called prior to initializing the new worker in order to get new
Expand All @@ -405,15 +426,28 @@ def _refresh_actor_cls(self):
self._actor_cls = ray.remote(**remote_args)(self._map_worker_cls)
return new_and_overriden_remote_args

def has_next(self) -> bool:
    """Return whether the operator has an output bundle ready.

    Once all inputs are ingested, bundles still sitting in the internal
    queue (e.g. the one that was pending in the bundler) will no longer be
    dispatched via ``add_input``, so flush them here before delegating to
    the base implementation.
    """
    # In case there are still enqueued bundles remaining, try to
    # dispatch tasks.
    if self._inputs_complete and self._bundle_queue.num_blocks() > 0:
        # NOTE: No more than 1 bundle is expected to remain in the queue
        # upon inputs completion (the one that was pending in the bundler),
        # so only warn when we observe more than that.
        if self._bundle_queue.num_bundles() > 1:
            logger.warning(
                f"Expected at most 1 bundle to remain in the input queue of {self} "
                f"(found {self._bundle_queue.num_bundles()} bundles, with "
                f"{self._bundle_queue.num_blocks()} blocks)"
            )

        # Schedule tasks handling remaining bundles. Non-strict: scheduling
        # may legitimately fail if no actor currently has free task slots.
        self._try_schedule_tasks_internal(strict=False)

    return super().has_next()

def all_inputs_done(self):
# Call base implementation to handle any leftover bundles. This may or may not
# trigger task dispatch.
super().all_inputs_done()

# Mark inputs as done so future task dispatch will kill all inactive workers
# once the bundle queue is exhausted.
self._inputs_done = True

if self._metrics.num_inputs_received < self._actor_pool.min_size():
warnings.warn(
f"The minimum number of concurrent actors for '{self.name}' is set to "
Expand Down Expand Up @@ -565,22 +599,11 @@ def min_scheduling_resources(
return self._actor_pool.per_actor_resource_usage()

def update_resource_usage(self) -> None:
    """Refresh internal state derived from actors' externally-observed status.

    Re-polls every running actor's state via the actor pool, then lets the
    task selector re-snapshot whatever state it caches off the pool.
    """
    # Trigger Actor Pool's state refresh
    self._actor_pool.refresh_actor_state()
    self._actor_task_selector.refresh_state()

def get_actor_info(self) -> _ActorPoolInfo:
"""Returns Actor counts for Alive, Restarting and Pending Actors."""
Expand Down Expand Up @@ -697,9 +720,32 @@ def __init__(self, actor_pool: "_ActorPool"):
"""
self._actor_pool = actor_pool

def refresh_state(self):
    """Callback to refresh selector state that may depend on external data.

    NOTE: Such data has to be snapshotted inside the selector and may only
    change upon invocation of this method. The default implementation is a
    no-op; subclasses caching external state should override it.
    """
    pass

@abstractmethod
def can_schedule_task(self) -> bool:
    """Check whether there are actors available to schedule at least 1 task.

    NOTE: This method has to be consistent with the `select_actors(...)`
    method, i.e. whenever `can_schedule_task` returns `True`,
    `select_actors` must yield at least 1 actor.

    TODO: deduplicate with `select_actors`.
    """
    ...

@abstractmethod
def select_actors(
self, input_queue: QueueWithRemoval, actor_locality_enabled: bool
self,
input_queue: QueueWithRemoval,
actor_locality_enabled: bool,
strict: bool,
) -> Iterator[Tuple[RefBundle, ActorHandle]]:
"""Select actors for bundles in the input queue.

Expand All @@ -718,43 +764,47 @@ class _ActorTaskSelectorImpl(_ActorTaskSelector):
def __init__(self, actor_pool: "_ActorPool"):
super().__init__(actor_pool)

def can_schedule_task(self) -> bool:
    """Return whether the pool currently has at least 1 schedulable actor."""
    return len(self._actor_pool.schedulable_actors()) > 0

def select_actors(
self, input_queue: QueueWithRemoval, actor_locality_enabled: bool
self,
input_queue: QueueWithRemoval,
actor_locality_enabled: bool,
strict: bool,
) -> Iterator[Tuple[RefBundle, ActorHandle]]:
"""Picks actors for task submission based on busyness and locality."""
if not self._actor_pool.running_actors():
# Actor pool is empty or all actors are still pending.
return

assert (
not strict or self.can_schedule_task()
), "select_actors(...) might not be invoked unless can_schedule_task(...) returns true"

while input_queue:
# Filter out actors that are invalid, i.e. actors with number of tasks in
# flight >= _max_tasks_in_flight or actor_state is not ALIVE.
bundle = input_queue.peek_next()
valid_actors = [
actor
for actor in self._actor_pool.running_actors()
if self._actor_pool.running_actors()[actor].num_tasks_in_flight
< self._actor_pool.max_tasks_in_flight_per_actor()
and not self._actor_pool.running_actors()[actor].is_restarting
]

if not valid_actors:
# All actors are at capacity or actor state is not ALIVE.
# Fetch available actors
available_actors = self._actor_pool.schedulable_actors()
if not available_actors:
return

# Rank all valid actors
ranks = self._rank_actors(
valid_actors, bundle if actor_locality_enabled else None
available_actors, bundle if actor_locality_enabled else None
)

assert len(ranks) == len(
valid_actors
), f"{len(ranks)} != {len(valid_actors)}"
available_actors
), f"{len(ranks)} != {len(available_actors)}"

# Pick the actor with the highest rank (lower value, higher rank)
target_actor_idx = min(range(len(valid_actors)), key=lambda idx: ranks[idx])
target_actor_idx = min(
range(len(available_actors)), key=lambda idx: ranks[idx]
)

target_actor = valid_actors[target_actor_idx]
target_actor = available_actors[target_actor_idx]

# We remove the bundle and yield the actor to the operator. We do not use pop()
# in case the queue has changed the order of the bundles.
Expand Down Expand Up @@ -930,6 +980,9 @@ def num_tasks_in_flight(self) -> int:
def initial_size(self) -> int:
    """Return the initial size of the actor pool."""
    return self._initial_size

def get_actor_id(self, actor: ActorHandle) -> str:
    """Return the logical actor ID mapped to the given actor handle.

    Raises ``KeyError`` if the handle is not tracked by this pool.
    """
    return self._actor_to_logical_id[actor]

def _can_apply(self, config: ActorPoolScalingRequest) -> bool:
"""Returns whether Actor Pool is able to execute scaling request"""

Expand Down Expand Up @@ -1015,14 +1068,39 @@ def _create_actor(self) -> Tuple[ray.actor.ActorHandle, ObjectRef]:
def running_actors(self) -> Dict[ray.actor.ActorHandle, _ActorState]:
return self._running_actors

def schedulable_actors(self) -> List[ray.actor.ActorHandle]:
    """Return running actors that can accept another task.

    An actor is schedulable iff it is not restarting and its number of
    in-flight tasks is below the per-actor cap.
    """
    # Hoisted out of the comprehension: the cap is identical for every actor,
    # so there's no need to re-fetch it once per running actor.
    max_in_flight = self.max_tasks_in_flight_per_actor()
    return [
        actor
        for actor, state in self._running_actors.items()
        if state.num_tasks_in_flight < max_in_flight and not state.is_restarting
    ]

def on_task_submitted(self, actor: ray.actor.ActorHandle):
    """Record that a new task has been submitted to the given running actor."""
    actor_state = self._running_actors[actor]
    actor_state.num_tasks_in_flight += 1
    self._total_num_tasks_in_flight += 1

    # The first in-flight task flips the actor from idle to active.
    if actor_state.num_tasks_in_flight == 1:
        self._num_active_actors += 1

def update_running_actor_state(
def refresh_actor_state(self):
    """Re-poll the locally-cached state of every running actor.

    Actors observed as RESTARTING are marked restarting; ALIVE actors are
    marked not restarting; DEAD or unknown (None) states are skipped.
    """
    for actor in self.get_running_actor_refs():
        actor_state = actor._get_local_state()
        if actor_state in (None, gcs_pb2.ActorTableData.ActorState.DEAD):
            # actor._get_local_state can return None if the state is Unknown
            # If actor_state is None or dead, there is nothing to do.
            continue
        elif actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE:
            # The actors can be either ALIVE or RESTARTING here because they will
            # be restarted indefinitely until execution finishes.
            assert (
                actor_state == gcs_pb2.ActorTableData.ActorState.RESTARTING
            ), actor_state
            self._update_running_actor_state(actor, True)
        else:
            self._update_running_actor_state(actor, False)

def _update_running_actor_state(
self, actor: ray.actor.ActorHandle, is_restarting: bool
) -> None:
"""Update running actor state.
Expand Down Expand Up @@ -1099,7 +1177,7 @@ def on_task_completed(self, actor: ray.actor.ActorHandle):
def get_pending_actor_refs(self) -> List[ray.ObjectRef]:
    """Return the object refs of all actors still pending startup."""
    # Iterating a dict yields its keys, so no explicit .keys() call needed.
    return list(self._pending_actors)

def get_running_actor_refs(self) -> List[ray.ObjectRef]:
def get_running_actor_refs(self) -> List[ActorHandle]:
    """Return the handles of all currently-running actors."""
    # Iterating a dict yields its keys, so no explicit .keys() call needed.
    return list(self._running_actors)

def get_logical_ids(self) -> List[str]:
Expand Down
Loading