Skip to content

Commit 48bd1f8

Browse files
authored
[data] revert "continue grabbing task state until response is not None" (#61066)
Reverts #60592 and cherry-picks #61064. Signed-off-by: Lonnie Liu <lonnie@anyscale.com>
1 parent 165b4aa commit 48bd1f8

File tree

2 files changed

+23
-27
lines changed

2 files changed

+23
-27
lines changed

python/ray/data/_internal/issue_detection/detectors/hanging_detector.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
from collections import defaultdict
44
from dataclasses import dataclass, field
5-
from typing import TYPE_CHECKING, Dict, List, Optional, Set
5+
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
66

77
import ray
88
from ray.data._internal.issue_detection.issue_detector import (
@@ -35,7 +35,6 @@
3535
class HangingExecutionState:
3636
operator_id: str
3737
task_idx: int
38-
task_id: ray.TaskID
3938
task_state: Optional[TaskState]
4039
bytes_output: int
4140
start_time_hanging: float
@@ -113,7 +112,7 @@ def _create_issues(
113112
attempt_number = state.task_state.attempt_number
114113

115114
message = (
116-
f"A task (task_id={state.task_id}) of operator {op_name} (pid={pid}, node_id={node_id}, attempt={attempt_number}) has been running for {duration:.2f}s, which is longer"
115+
f"A task of operator {op_name} (pid={pid}, node_id={node_id}, attempt={attempt_number}) has been running for {duration:.2f}s, which is longer"
117116
f" than the average task duration of this operator ({avg_duration:.2f}s)."
118117
f" If this message persists, please check the stack trace of the "
119118
"task for potential hanging issues."
@@ -154,11 +153,29 @@ def detect(self) -> List[Issue]:
154153
prev_state_value is None
155154
or bytes_output != prev_state_value.bytes_output
156155
):
156+
task_state = None
157+
try:
158+
task_state: Union[
159+
TaskState, List[TaskState]
160+
] = ray.util.state.get_task(
161+
task_info.task_id.hex(),
162+
timeout=1.0,
163+
_explain=True,
164+
)
165+
if isinstance(task_state, list):
166+
# get the latest task
167+
task_state = max(
168+
task_state, key=lambda ts: ts.attempt_number
169+
)
170+
except Exception as e:
171+
logger.debug(
172+
f"Failed to grab task state with task_index={task_idx}, task_id={task_info.task_id}: {e}"
173+
)
174+
pass
157175
self._state_map[operator.id][task_idx] = HangingExecutionState(
158176
operator_id=operator.id,
159177
task_idx=task_idx,
160-
task_id=task_info.task_id,
161-
task_state=None,
178+
task_state=task_state,
162179
bytes_output=bytes_output,
163180
start_time_hanging=time.perf_counter(),
164181
)
@@ -177,10 +194,6 @@ def detect(self) -> List[Issue]:
177194
for task_idx, state_value in op_state_values.items():
178195
curr_time = time.perf_counter() - state_value.start_time_hanging
179196
if op_task_stats.count() >= self._op_task_stats_min_count:
180-
if state_value.task_state is None:
181-
state_value.task_state = get_latest_state_for_task(
182-
state_value.task_id
183-
)
184197
mean = op_task_stats.mean()
185198
stddev = op_task_stats.stddev()
186199
threshold = mean + self._op_task_stats_std_factor_threshold * stddev
@@ -198,20 +211,3 @@ def detect(self) -> List[Issue]:
198211

199212
def detection_time_interval_s(self) -> float:
200213
return self._detector_cfg.detection_time_interval_s
201-
202-
203-
def get_latest_state_for_task(task_id: ray.TaskID) -> TaskState | None:
204-
try:
205-
task_state: TaskState | List[TaskState] | None = ray.util.state.get_task(
206-
task_id.hex(),
207-
timeout=1.0,
208-
_explain=True,
209-
)
210-
if isinstance(task_state, list):
211-
# get the latest task
212-
task_state = max(task_state, key=lambda ts: ts.attempt_number)
213-
return task_state
214-
except Exception as e:
215-
logger.debug(f"Failed to grab task state with task_id={task_id}: {e}")
216-
pass
217-
return None

python/ray/data/tests/test_issue_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def f1(x):
123123
_ = ray.data.range(1).map(f1).materialize()
124124

125125
log_output = log_capture.getvalue()
126-
warn_msg = r"A task \(task_id=.+\) .+ \(pid=.+, node_id=.+, attempt=.+\) has been running for [\d\.]+s"
126+
warn_msg = r"A task of operator .+ \(pid=.+, node_id=.+, attempt=.+\) has been running for [\d\.]+s"
127127
assert re.search(warn_msg, log_output) is None, log_output
128128

129129
# # test hanging does log hanging warning

0 commit comments

Comments (0)