Skip to content

Commit ee1bd47

Browse files
alexeykudinkin and xinyuangui2
authored and committed
[Data] Fixing prefetch loop to avoid being blocked on the block being fetched (ray-project#57613)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? 1. Fixing prefetcher loop to avoid being blocked on the next block being fetched 2. Adding missing metrics for `BatchIterator` ## Related issue number <!-- For example: "Closes ray-project#1234" --> ## Checks - [ ] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [ ] I've run pre-commit jobs to lint the changes in this PR. ([pre-commit setup](https://docs.ray.io/en/latest/ray-contribute/getting-involved.html#lint-and-formatting)) - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [ ] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Alexey Kudinkin <ak@anyscale.com> Signed-off-by: xgui <xgui@anyscale.com>
1 parent 3f04109 commit ee1bd47

File tree

2 files changed

+63
-18
lines changed

2 files changed

+63
-18
lines changed

python/ray/data/_internal/block_batching/util.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,22 +232,30 @@ def __init__(self):
232232
self._thread.start()
233233

234234
def _run(self):
235-
while True:
235+
while not self._stopped:
236236
try:
237-
blocks_to_wait = []
238237
with self._condition:
239-
if len(self._blocks) > 0:
240-
blocks_to_wait, self._blocks = self._blocks[:], []
241-
else:
242-
if self._stopped:
243-
return
244-
blocks_to_wait = []
238+
if len(self._blocks) == 0:
239+
# Park, waiting for notification that prefetching
240+
# should resume
245241
self._condition.wait()
246-
if len(blocks_to_wait) > 0:
247-
ray.wait(blocks_to_wait, num_returns=1, fetch_local=True)
242+
243+
blocks_to_fetch, self._blocks = self._blocks[:], []
244+
245+
if len(blocks_to_fetch) > 0:
246+
ray.wait(
247+
blocks_to_fetch,
248+
num_returns=1,
249+
# NOTE: We deliberately set timeout to 0 to avoid
250+
# blocking the fetching thread unnecessarily
251+
timeout=0,
252+
fetch_local=True,
253+
)
248254
except Exception:
249255
logger.exception("Error in prefetcher thread.")
250256

257+
logger.info("Exiting prefetcher's background thread")
258+
251259
def prefetch_blocks(self, blocks: List[ObjectRef[Block]]):
252260
with self._condition:
253261
if self._stopped:

python/ray/data/_internal/stats.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -278,17 +278,45 @@ def __init__(self, max_stats=1000):
278278
self.per_node_metrics = self._create_prometheus_metrics_for_per_node_metrics()
279279

280280
iter_tag_keys = ("dataset",)
281-
self.iter_total_blocked_s = Gauge(
282-
"data_iter_total_blocked_seconds",
283-
description="Seconds user thread is blocked by iter_batches()",
284-
tag_keys=iter_tag_keys,
285-
)
281+
286282
self.time_to_first_batch_s = Gauge(
287283
"data_iter_time_to_first_batch_seconds",
288284
description="Total time spent waiting for the first batch after starting iteration. "
289285
"This includes the dataset pipeline warmup time. This metric is accumulated across different epochs.",
290286
tag_keys=iter_tag_keys,
291287
)
288+
289+
self.iter_block_fetching_s = Gauge(
290+
"data_iter_block_fetching_seconds",
291+
description="Seconds taken to fetch (with ray.get) blocks by iter_batches()",
292+
tag_keys=iter_tag_keys,
293+
)
294+
self.iter_batch_shaping_s = Gauge(
295+
"data_iter_batch_shaping_seconds",
296+
description="Seconds taken to shape batch from incoming blocks by iter_batches()",
297+
tag_keys=iter_tag_keys,
298+
)
299+
self.iter_batch_formatting_s = Gauge(
300+
"data_iter_batch_formatting_seconds",
301+
description="Seconds taken to format batches by iter_batches()",
302+
tag_keys=iter_tag_keys,
303+
)
304+
self.iter_batch_collating_s = Gauge(
305+
"data_iter_batch_collating_seconds",
306+
description="Seconds taken to collate batches by iter_batches()",
307+
tag_keys=iter_tag_keys,
308+
)
309+
self.iter_batch_finalizing_s = Gauge(
310+
"data_iter_batch_finalizing_seconds",
311+
description="Seconds taken to finalize batches by iter_batches()",
312+
tag_keys=iter_tag_keys,
313+
)
314+
315+
self.iter_total_blocked_s = Gauge(
316+
"data_iter_total_blocked_seconds",
317+
description="Seconds user thread is blocked by iter_batches()",
318+
tag_keys=iter_tag_keys,
319+
)
292320
self.iter_user_s = Gauge(
293321
"data_iter_user_seconds",
294322
description="Seconds spent in user code",
@@ -517,9 +545,7 @@ def update_iteration_metrics(
517545
dataset_tag,
518546
):
519547
tags = self._create_tags(dataset_tag)
520-
self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
521-
self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags)
522-
self.iter_user_s.set(stats.iter_user_s.get(), tags)
548+
523549
self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)
524550
self.iter_get_s.set(stats.iter_get_s.get(), tags)
525551
self.iter_next_batch_s.set(stats.iter_next_batch_s.get(), tags)
@@ -530,6 +556,17 @@ def update_iteration_metrics(
530556
self.iter_blocks_remote.set(stats.iter_blocks_remote, tags)
531557
self.iter_unknown_location.set(stats.iter_unknown_location, tags)
532558

559+
self.iter_block_fetching_s.set(stats.iter_get_s.get(), tags)
560+
self.iter_batch_shaping_s.set(stats.iter_next_batch_s.get(), tags)
561+
self.iter_batch_formatting_s.set(stats.iter_format_batch_s.get(), tags)
562+
self.iter_batch_collating_s.set(stats.iter_collate_batch_s.get(), tags)
563+
self.iter_batch_finalizing_s.set(stats.iter_finalize_batch_s.get(), tags)
564+
565+
self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags)
566+
567+
self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
568+
self.iter_user_s.set(stats.iter_user_s.get(), tags)
569+
533570
def register_dataset(
534571
self,
535572
job_id: str,

0 commit comments

Comments (0)