Skip to content
Merged
19 changes: 17 additions & 2 deletions python/ray/data/_internal/block_batching/iter_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(
if actor_prefetcher_enabled
else WaitBlockPrefetcher()
)
self.waiting_for_first_batch = True

def _prefetch_blocks(
self, ref_bundles: Iterator[RefBundle]
Expand Down Expand Up @@ -242,8 +243,22 @@ def after_epoch_end(self):

@contextmanager
def get_next_batch_context(self):
with self._stats.iter_total_blocked_s.timer() if self._stats else nullcontext():
yield
try:
if self._stats:
# Always track total blocked time
total_timer = self._stats.iter_total_blocked_s.timer()
# Also track first batch blocked time if this is the first batch
first_batch_timer = (
self._stats.iter_first_batch_blocked_s.timer()
if self.waiting_for_first_batch
else nullcontext()
)
with total_timer, first_batch_timer:
yield
else:
yield
finally:
self.waiting_for_first_batch = False

@contextmanager
def yield_batch_context(self, batch: Batch):
Expand Down
18 changes: 18 additions & 0 deletions python/ray/data/_internal/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,11 @@ def __init__(self, max_stats=1000):
description="Seconds user thread is blocked by iter_batches()",
tag_keys=iter_tag_keys,
)
self.iter_first_batch_blocked_s = Gauge(
"data_iter_first_batch_blocked_seconds",
description="Seconds user thread is blocked waiting for first batch",
tag_keys=iter_tag_keys,
)
self.iter_user_s = Gauge(
"data_iter_user_seconds",
description="Seconds spent in user code",
Expand Down Expand Up @@ -469,6 +474,9 @@ def update_iteration_metrics(
):
tags = self._create_tags(dataset_tag)
self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
self.iter_first_batch_blocked_s.set(
stats.iter_first_batch_blocked_s.get(), tags
)
self.iter_user_s.set(stats.iter_user_s.get(), tags)
self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)

Expand Down Expand Up @@ -948,6 +956,7 @@ def __init__(
self.iter_format_batch_s: Timer = Timer()
self.iter_collate_batch_s: Timer = Timer()
self.iter_finalize_batch_s: Timer = Timer()
self.iter_first_batch_blocked_s: Timer = Timer()
self.iter_total_blocked_s: Timer = Timer()
self.iter_user_s: Timer = Timer()
self.iter_initialize_s: Timer = Timer()
Expand Down Expand Up @@ -1003,6 +1012,7 @@ def to_summary(self) -> "DatasetStatsSummary":
self.iter_format_batch_s,
self.iter_collate_batch_s,
self.iter_finalize_batch_s,
self.iter_first_batch_blocked_s,
self.iter_total_blocked_s,
self.iter_user_s,
self.iter_initialize_s,
Expand Down Expand Up @@ -1642,6 +1652,8 @@ class IterStatsSummary:
collate_time: Timer
# Time spent in finalize_fn, in seconds
finalize_batch_time: Timer
# Time user thread is blocked waiting for first batch
first_batch_block_time: Timer
# Total time user thread is blocked by iter_batches
block_time: Timer
# Time spent in user code, in seconds
Expand All @@ -1665,6 +1677,7 @@ def to_string(self) -> str:
out = ""
if (
self.block_time.get()
or self.first_batch_block_time.get()
or self.total_time.get()
or self.get_time.get()
or self.next_time.get()
Expand All @@ -1685,6 +1698,11 @@ def to_string(self) -> str:
" * Total time user thread is blocked by Ray Data iter_batches: "
"{}\n".format(fmt(self.block_time.get()))
)
if self.first_batch_block_time.get():
out += (
" * Total time user thread is blocked waiting for first batch: "
"{}\n".format(fmt(self.first_batch_block_time.get()))
)
if self.user_time.get():
out += " * Total execution time for user thread: {}\n".format(
fmt(self.user_time.get())
Expand Down
4 changes: 4 additions & 0 deletions python/ray/data/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def test_streaming_split_stats(ray_start_regular_shared, restore_data_context):
* Total time overall: T
* Total time in Ray Data iterator initialization code: T
* Total time user thread is blocked by Ray Data iter_batches: T
* Total time user thread is blocked waiting for first batch: T
* Total execution time for user thread: T
* Batch iteration time breakdown (summed across prefetch threads):
* In ray.get(): T min, T max, T avg, T total
Expand Down Expand Up @@ -577,6 +578,7 @@ def test_dataset_stats_basic(
f"* Total time overall: T\n"
f" * Total time in Ray Data iterator initialization code: T\n"
f" * Total time user thread is blocked by Ray Data iter_batches: T\n"
f" * Total time user thread is blocked waiting for first batch: T\n"
f" * Total execution time for user thread: T\n"
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
f" * In ray.get(): T min, T max, T avg, T total\n"
Expand Down Expand Up @@ -618,6 +620,7 @@ def test_block_location_nums(ray_start_regular_shared, restore_data_context):
f"* Total time overall: T\n"
f" * Total time in Ray Data iterator initialization code: T\n"
f" * Total time user thread is blocked by Ray Data iter_batches: T\n"
f" * Total time user thread is blocked waiting for first batch: T\n"
f" * Total execution time for user thread: T\n"
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
f" * In ray.get(): T min, T max, T avg, T total\n"
Expand Down Expand Up @@ -1363,6 +1366,7 @@ def test_streaming_stats_full(ray_start_regular_shared, restore_data_context):
* Total time overall: T
* Total time in Ray Data iterator initialization code: T
* Total time user thread is blocked by Ray Data iter_batches: T
* Total time user thread is blocked waiting for first batch: T
* Total execution time for user thread: T
* Batch iteration time breakdown (summed across prefetch threads):
* In ray.get(): T min, T max, T avg, T total
Expand Down