22import time
33from collections import defaultdict
44from dataclasses import dataclass , field
5- from typing import TYPE_CHECKING , Dict , List , Optional , Set
5+ from typing import TYPE_CHECKING , Dict , List , Optional , Set , Union
66
77import ray
88from ray .data ._internal .issue_detection .issue_detector import (
3535class HangingExecutionState :
3636 operator_id : str
3737 task_idx : int
38- task_id : ray .TaskID
3938 task_state : Optional [TaskState ]
4039 bytes_output : int
4140 start_time_hanging : float
@@ -113,7 +112,7 @@ def _create_issues(
113112 attempt_number = state .task_state .attempt_number
114113
115114 message = (
116- f"A task (task_id= { state . task_id } ) of operator { op_name } (pid={ pid } , node_id={ node_id } , attempt={ attempt_number } ) has been running for { duration :.2f} s, which is longer"
115+ f"A task of operator { op_name } (pid={ pid } , node_id={ node_id } , attempt={ attempt_number } ) has been running for { duration :.2f} s, which is longer"
117116 f" than the average task duration of this operator ({ avg_duration :.2f} s)."
118117 f" If this message persists, please check the stack trace of the "
119118 "task for potential hanging issues."
@@ -154,11 +153,29 @@ def detect(self) -> List[Issue]:
154153 prev_state_value is None
155154 or bytes_output != prev_state_value .bytes_output
156155 ):
156+ task_state = None
157+ try :
158+ task_state : Union [
159+ TaskState , List [TaskState ]
160+ ] = ray .util .state .get_task (
161+ task_info .task_id .hex (),
162+ timeout = 1.0 ,
163+ _explain = True ,
164+ )
165+ if isinstance (task_state , list ):
166+ # get the latest task
167+ task_state = max (
168+ task_state , key = lambda ts : ts .attempt_number
169+ )
170+ except Exception as e :
171+ logger .debug (
172+ f"Failed to grab task state with task_index={ task_idx } , task_id={ task_info .task_id } : { e } "
173+ )
174+ pass
157175 self ._state_map [operator .id ][task_idx ] = HangingExecutionState (
158176 operator_id = operator .id ,
159177 task_idx = task_idx ,
160- task_id = task_info .task_id ,
161- task_state = None ,
178+ task_state = task_state ,
162179 bytes_output = bytes_output ,
163180 start_time_hanging = time .perf_counter (),
164181 )
@@ -177,10 +194,6 @@ def detect(self) -> List[Issue]:
177194 for task_idx , state_value in op_state_values .items ():
178195 curr_time = time .perf_counter () - state_value .start_time_hanging
179196 if op_task_stats .count () >= self ._op_task_stats_min_count :
180- if state_value .task_state is None :
181- state_value .task_state = get_latest_state_for_task (
182- state_value .task_id
183- )
184197 mean = op_task_stats .mean ()
185198 stddev = op_task_stats .stddev ()
186199 threshold = mean + self ._op_task_stats_std_factor_threshold * stddev
@@ -198,20 +211,3 @@ def detect(self) -> List[Issue]:
198211
def detection_time_interval_s(self) -> float:
    """Return the ``detection_time_interval_s`` value held by ``self._detector_cfg``."""
    return self._detector_cfg.detection_time_interval_s
201-
202-
def get_latest_state_for_task(task_id: ray.TaskID) -> "Optional[TaskState]":
    """Look up the most recent state of a task through the Ray state API.

    The lookup is best-effort: any exception raised while querying or while
    picking the latest attempt is logged at debug level and ``None`` is
    returned instead of propagating.

    Args:
        task_id: ID of the task to query; its ``hex()`` form is passed to
            ``ray.util.state.get_task``.

    Returns:
        The state returned by ``ray.util.state.get_task``. When that call
        returns a list, the entry with the highest ``attempt_number`` is
        returned (``None`` for an empty list). ``None`` is also returned
        when the lookup fails.
    """
    # The return annotation is a string on purpose: it is never evaluated at
    # definition time, so it works on interpreters without PEP 604 unions and
    # when TaskState is only importable for type checking.
    try:
        task_state = ray.util.state.get_task(
            task_id.hex(),
            timeout=1.0,
            _explain=True,
        )
        if isinstance(task_state, list):
            # Several entries may come back for one task; keep the latest
            # attempt. ``default=None`` covers an empty list without raising.
            task_state = max(
                task_state,
                key=lambda ts: ts.attempt_number,
                default=None,
            )
        return task_state
    except Exception as e:
        logger.debug(f"Failed to grab task state with task_id={task_id}: {e}")
    return None
0 commit comments