@@ -191,8 +191,6 @@ def __init__(
191191 self ._map_worker_cls = type (f"MapWorker({ self .name } )" , (_MapWorker ,), {})
192192 # Cached actor class.
193193 self ._actor_cls = None
194- # Whether no more submittable bundles will be added.
195- self ._inputs_done = False
196194 self ._actor_locality_enabled : Optional [bool ] = None
197195
198196 # Locality metrics
@@ -280,8 +278,18 @@ def start(self, options: ExecutionOptions):
280278 "enough resources for the requested actor pool."
281279 )
282280
283- def should_add_input (self ) -> bool :
284- return self ._actor_pool .num_free_task_slots () > 0
281+ def can_add_input (self ) -> bool :
282+ """NOTE: PLEASE READ CAREFULLY
283+
284+ This method has to abide by the following contract to guarantee Operator's
285+ ability to handle all provided inputs (liveness):
286+
287+ - This method should only return `True` when the operator is guaranteed
288+ to be able to launch a task, meaning that a subsequent `op.add_input(...)`
289+ call should be able to launch a task.
290+
291+ """
292+ return self ._actor_task_selector .can_schedule_task ()
285293
286294 def _start_actor (
287295 self , labels : Dict [str , str ], logical_actor_id : str
@@ -317,33 +325,44 @@ def _task_done_callback(res_ref):
317325 if not has_actor :
318326 # Actor has already been killed.
319327 return
320- # Try to dispatch queued tasks.
321- self ._dispatch_tasks ()
322328
323329 self ._submit_metadata_task (
324330 res_ref ,
325331 lambda : _task_done_callback (res_ref ),
326332 )
327333 return actor , res_ref
328334
329- def _add_bundled_input (self , bundle : RefBundle ):
335+ def _try_schedule_task (self , bundle : RefBundle , strict : bool ):
330336 # Notify first input for deferred initialization (e.g., Iceberg schema evolution).
331337 self ._notify_first_input (bundle )
338+ # Enqueue input bundle
332339 self ._bundle_queue .add (bundle )
333340 self ._metrics .on_input_queued (bundle )
334- # Try to dispatch all bundles in the queue, including this new bundle.
335- self ._dispatch_tasks ()
336341
337- def _dispatch_tasks (self ):
338- """Try to dispatch tasks from the bundle buffer to the actor pool.
342+ if strict :
343+ # NOTE: Under the strict input handling protocol at least 1 task
344+ # must be launched. Therefore, we assert here that the caller
345+ # has verified that a task can be scheduled before invoking
346+ # this method.
347+ assert self .can_add_input (), f"Operator { self } can not handle input!"
348+
349+ # Try to dispatch new tasks
350+ submitted = self ._try_schedule_tasks_internal (strict = strict )
351+
352+ if strict :
353+ assert (
354+ submitted >= 1
355+ ), f"Expected at least 1 task launched (launched { submitted } )"
356+
357+ def _try_schedule_tasks_internal (self , strict : bool ) -> int :
358+ """Try to dispatch tasks from the internal queue"""
359+
360+ num_submitted_tasks = 0
339361
340- This is called when:
341- * a new input bundle is added,
342- * a task finishes,
343- * a new worker has been created.
344- """
345362 for bundle , actor in self ._actor_task_selector .select_actors (
346- self ._bundle_queue , self ._actor_locality_enabled
363+ self ._bundle_queue ,
364+ self ._actor_locality_enabled ,
365+ strict = strict ,
347366 ):
348367 # Submit the map task.
349368 self ._metrics .on_input_dequeued (bundle )
@@ -369,15 +388,15 @@ def _dispatch_tasks(self):
369388 def _task_done_callback (actor_to_return ):
370389 # Return the actor that was running the task to the pool.
371390 self ._actor_pool .on_task_completed (actor_to_return )
372- # Dipsatch more tasks.
373- self ._dispatch_tasks ()
374391
375392 from functools import partial
376393
377394 self ._submit_data_task (
378395 gen , bundle , partial (_task_done_callback , actor_to_return = actor )
379396 )
380397
398+ num_submitted_tasks += 1
399+
381400 # Update locality metrics
382401 if (
383402 self ._actor_pool .running_actors ()[actor ].actor_location
@@ -387,6 +406,8 @@ def _task_done_callback(actor_to_return):
387406 else :
388407 self ._locality_misses += 1
389408
409+ return num_submitted_tasks
410+
390411 def _refresh_actor_cls (self ):
391412 """When `self._ray_remote_args_fn` is specified, this method should
392413 be called prior to initializing the new worker in order to get new
@@ -405,15 +426,28 @@ def _refresh_actor_cls(self):
405426 self ._actor_cls = ray .remote (** remote_args )(self ._map_worker_cls )
406427 return new_and_overriden_remote_args
407428
429+ def has_next (self ) -> bool :
430+ # In case there are still enqueued bundles remaining, try to
431+ # dispatch tasks.
432+ if self ._inputs_complete and self ._bundle_queue .num_blocks () > 0 :
433+ # NOTE: No more than 1 bundle is expected to be in the queue
434+ # upon inputs completion (the one that was pending in the bundler)
435+ if self ._bundle_queue .num_bundles () > 1 :
436+ logger .warning (
437+ f"Expected 1 bundle to remain in the input queue of { self } "
438+ f"(found { self ._bundle_queue .num_bundles ()} bundles, with { self ._bundle_queue .num_blocks ()} blocks)"
439+ )
440+
441+ # Schedule tasks handling remaining bundles
442+ self ._try_schedule_tasks_internal (strict = False )
443+
444+ return super ().has_next ()
445+
408446 def all_inputs_done (self ):
409447 # Call base implementation to handle any leftover bundles. This may or may not
410448 # trigger task dispatch.
411449 super ().all_inputs_done ()
412450
413- # Mark inputs as done so future task dispatch will kill all inactive workers
414- # once the bundle queue is exhausted.
415- self ._inputs_done = True
416-
417451 if self ._metrics .num_inputs_received < self ._actor_pool .min_size ():
418452 warnings .warn (
419453 f"The minimum number of concurrent actors for '{ self .name } ' is set to "
@@ -565,22 +599,11 @@ def min_scheduling_resources(
565599 return self ._actor_pool .per_actor_resource_usage ()
566600
567601 def update_resource_usage (self ) -> None :
568- """Updates resources usage."""
569- for actor in self ._actor_pool .get_running_actor_refs ():
570- actor_state = actor ._get_local_state ()
571- if actor_state in (None , gcs_pb2 .ActorTableData .ActorState .DEAD ):
572- # actor._get_local_state can return None if the state is Unknown
573- # If actor_state is None or dead, there is nothing to do.
574- continue
575- elif actor_state != gcs_pb2 .ActorTableData .ActorState .ALIVE :
576- # The actors can be either ALIVE or RESTARTING here because they will
577- # be restarted indefinitely until execution finishes.
578- assert (
579- actor_state == gcs_pb2 .ActorTableData .ActorState .RESTARTING
580- ), actor_state
581- self ._actor_pool .update_running_actor_state (actor , True )
582- else :
583- self ._actor_pool .update_running_actor_state (actor , False )
602+ """Updates internal state"""
603+
604+ # Trigger Actor Pool's state refresh
605+ self ._actor_pool .refresh_actor_state ()
606+ self ._actor_task_selector .refresh_state ()
584607
585608 def get_actor_info (self ) -> _ActorPoolInfo :
586609 """Returns Actor counts for Alive, Restarting and Pending Actors."""
@@ -697,15 +720,40 @@ def __init__(self, actor_pool: "_ActorPool"):
697720 """
698721 self ._actor_pool = actor_pool
699722
723+ def refresh_state (self ):
724+ """Callback to refresh selector's state that might depend on external data
725+
726+ NOTE: This data has to be snapshotted inside the selector, and
727+ can only change when this method is invoked"""
728+ pass
729+
730+ @abstractmethod
731+ def can_schedule_task (self ) -> bool :
732+ """Checks whether there are actors available to schedule at least 1 task
733+
734+ NOTE: This method has to be consistent with the `select_actors(...)` method, i.e.
735+
736+ - If `can_schedule_task` returns `True`, then
737+ - `select_actors` must return at least 1 actor
738+
739+ TODO deduplicate with select_actors
740+ """
741+ ...
742+
700743 @abstractmethod
701744 def select_actors (
702- self , input_queue : QueueWithRemoval , actor_locality_enabled : bool
745+ self ,
746+ input_queue : QueueWithRemoval ,
747+ actor_locality_enabled : bool ,
748+ strict : bool ,
703749 ) -> Iterator [Tuple [RefBundle , ActorHandle ]]:
704750 """Select actors for bundles in the input queue.
705751
706752 Args:
707753 input_queue: The input queue to select actors for.
708754 actor_locality_enabled: Whether actor locality is enabled.
755+ strict: Controls whether strict input handling protocol is enforced,
756+ requiring at least 1 bundle to be matched with an actor
709757
710758 Returns:
711759 Iterator of tuples of the bundle and the selected actor for that bundle.
@@ -718,43 +766,45 @@ class _ActorTaskSelectorImpl(_ActorTaskSelector):
718766 def __init__ (self , actor_pool : "_ActorPool" ):
719767 super ().__init__ (actor_pool )
720768
769+ def can_schedule_task (self ) -> bool :
770+ available_actors = self ._actor_pool .schedulable_actors ()
771+
772+ return len (available_actors ) > 0
773+
721774 def select_actors (
722- self , input_queue : QueueWithRemoval , actor_locality_enabled : bool
775+ self ,
776+ input_queue : QueueWithRemoval ,
777+ actor_locality_enabled : bool ,
778+ strict : bool ,
723779 ) -> Iterator [Tuple [RefBundle , ActorHandle ]]:
724- """Picks actors for task submission based on busyness and locality."""
725- if not self ._actor_pool .running_actors ():
726- # Actor pool is empty or all actors are still pending.
727- return
780+ assert (
781+ not strict or self .can_schedule_task ()
782+ ), "select_actors(...) might not be invoked unless can_schedule_task(...) returns true"
728783
729784 while input_queue :
730785 # Filter out actors that are invalid, i.e. actors with number of tasks in
731786 # flight >= _max_tasks_in_flight or actor_state is not ALIVE.
732787 bundle = input_queue .peek_next ()
733- valid_actors = [
734- actor
735- for actor in self ._actor_pool .running_actors ()
736- if self ._actor_pool .running_actors ()[actor ].num_tasks_in_flight
737- < self ._actor_pool .max_tasks_in_flight_per_actor ()
738- and not self ._actor_pool .running_actors ()[actor ].is_restarting
739- ]
740-
741- if not valid_actors :
742- # All actors are at capacity or actor state is not ALIVE.
788+ # Fetch available actors
789+ available_actors = self ._actor_pool .schedulable_actors ()
790+ if not available_actors :
743791 return
744792
745793 # Rank all valid actors
746794 ranks = self ._rank_actors (
747- valid_actors , bundle if actor_locality_enabled else None
795+ available_actors , bundle if actor_locality_enabled else None
748796 )
749797
750798 assert len (ranks ) == len (
751- valid_actors
752- ), f"{ len (ranks )} != { len (valid_actors )} "
799+ available_actors
800+ ), f"{ len (ranks )} != { len (available_actors )} "
753801
754802 # Pick the actor with the highest rank (lower value, higher rank)
755- target_actor_idx = min (range (len (valid_actors )), key = lambda idx : ranks [idx ])
803+ target_actor_idx = min (
804+ range (len (available_actors )), key = lambda idx : ranks [idx ]
805+ )
756806
757- target_actor = valid_actors [target_actor_idx ]
807+ target_actor = available_actors [target_actor_idx ]
758808
759809 # We remove the bundle and yield the actor to the operator. We do not use pop()
760810 # in case the queue has changed the order of the bundles.
@@ -930,6 +980,9 @@ def num_tasks_in_flight(self) -> int:
930980 def initial_size (self ) -> int :
931981 return self ._initial_size
932982
983+ def get_actor_id (self , actor : ActorHandle ) -> str :
984+ return self ._actor_to_logical_id [actor ]
985+
933986 def _can_apply (self , config : ActorPoolScalingRequest ) -> bool :
934987 """Returns whether Actor Pool is able to execute scaling request"""
935988
@@ -1015,14 +1068,39 @@ def _create_actor(self) -> Tuple[ray.actor.ActorHandle, ObjectRef]:
10151068 def running_actors (self ) -> Dict [ray .actor .ActorHandle , _ActorState ]:
10161069 return self ._running_actors
10171070
1071+ def schedulable_actors (self ) -> List [ray .actor .ActorHandle ]:
1072+ return [
1073+ actor
1074+ for actor , state in self ._running_actors .items ()
1075+ if state .num_tasks_in_flight < self .max_tasks_in_flight_per_actor ()
1076+ and not state .is_restarting
1077+ ]
1078+
10181079 def on_task_submitted (self , actor : ray .actor .ActorHandle ):
10191080 self ._running_actors [actor ].num_tasks_in_flight += 1
10201081 self ._total_num_tasks_in_flight += 1
10211082
10221083 if self ._running_actors [actor ].num_tasks_in_flight == 1 :
10231084 self ._num_active_actors += 1
10241085
1025- def update_running_actor_state (
1086+ def refresh_actor_state (self ):
1087+ for actor in self .get_running_actor_refs ():
1088+ actor_state = actor ._get_local_state ()
1089+ if actor_state in (None , gcs_pb2 .ActorTableData .ActorState .DEAD ):
1090+ # actor._get_local_state can return None if the state is Unknown
1091+ # If actor_state is None or dead, there is nothing to do.
1092+ continue
1093+ elif actor_state != gcs_pb2 .ActorTableData .ActorState .ALIVE :
1094+ # The actors can be either ALIVE or RESTARTING here because they will
1095+ # be restarted indefinitely until execution finishes.
1096+ assert (
1097+ actor_state == gcs_pb2 .ActorTableData .ActorState .RESTARTING
1098+ ), actor_state
1099+ self ._update_running_actor_state (actor , True )
1100+ else :
1101+ self ._update_running_actor_state (actor , False )
1102+
1103+ def _update_running_actor_state (
10261104 self , actor : ray .actor .ActorHandle , is_restarting : bool
10271105 ) -> None :
10281106 """Update running actor state.
@@ -1099,7 +1177,7 @@ def on_task_completed(self, actor: ray.actor.ActorHandle):
10991177 def get_pending_actor_refs (self ) -> List [ray .ObjectRef ]:
11001178 return list (self ._pending_actors .keys ())
11011179
1102- def get_running_actor_refs (self ) -> List [ray . ObjectRef ]:
1180+ def get_running_actor_refs (self ) -> List [ActorHandle ]:
11031181 return list (self ._running_actors .keys ())
11041182
11051183 def get_logical_ids (self ) -> List [str ]:
0 commit comments