Skip to content

Commit 3401b03

Browse files
xinyuangui2, gemini-code-assist[bot], and justinvyu
authored and committed
[Data] Add time to first batch metric for dataset iterators (ray-project#55758)
The time to first batch usually takes longer time than the subsequent batches. This is because the time to first batch includes the time needed for the pipeline to warm up. The iterator receives the batch once the first few blocks have made it through all stages of the data pipeline and piped to the train worker consumers. Since we do prefetching and the data pipeline is in a steady state, so the time to produce subsequent batches is much lower. In this PR, we added a metric to track the time to first batch. --------- Signed-off-by: xgui <xgui@anyscale.com> Signed-off-by: Xinyuan <43737116+xinyuangui2@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Justin Yu <justinvyu@anyscale.com> Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
1 parent 0e500f1 commit 3401b03

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

python/ray/data/_internal/block_batching/iter_batches.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(
135135
if actor_prefetcher_enabled
136136
else WaitBlockPrefetcher()
137137
)
138+
self._yielded_first_batch = False
138139

139140
def _prefetch_blocks(
140141
self, ref_bundles: Iterator[RefBundle]
@@ -235,15 +236,29 @@ def __iter__(self) -> Iterator[DataBatch]:
235236
return self._iter_batches()
236237

237238
def before_epoch_start(self):
238-
pass
239+
self._yielded_first_batch = False
239240

240241
def after_epoch_end(self):
241242
StatsManager.clear_iteration_metrics(self._dataset_tag)
242243

243244
@contextmanager
244245
def get_next_batch_context(self):
245-
with self._stats.iter_total_blocked_s.timer() if self._stats else nullcontext():
246-
yield
246+
try:
247+
if self._stats:
248+
# Always track total blocked time
249+
total_timer = self._stats.iter_total_blocked_s.timer()
250+
# Also track the time until the first batch is ready
251+
first_batch_ready_timer = (
252+
self._stats.iter_time_to_first_batch_s.timer()
253+
if not self._yielded_first_batch
254+
else nullcontext()
255+
)
256+
with total_timer, first_batch_ready_timer:
257+
yield
258+
else:
259+
yield
260+
finally:
261+
self._yielded_first_batch = True
247262

248263
@contextmanager
249264
def yield_batch_context(self, batch: Batch):

python/ray/data/_internal/stats.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ def __init__(self, max_stats=1000):
280280
description="Seconds user thread is blocked by iter_batches()",
281281
tag_keys=iter_tag_keys,
282282
)
283+
self.time_to_first_batch_s = Gauge(
284+
"data_iter_time_to_first_batch_seconds",
285+
description="Total time spent waiting for the first batch after starting iteration. "
286+
"This includes the dataset pipeline warmup time. This metric is accumulated across different epochs.",
287+
tag_keys=iter_tag_keys,
288+
)
283289
self.iter_user_s = Gauge(
284290
"data_iter_user_seconds",
285291
description="Seconds spent in user code",
@@ -469,6 +475,7 @@ def update_iteration_metrics(
469475
):
470476
tags = self._create_tags(dataset_tag)
471477
self.iter_total_blocked_s.set(stats.iter_total_blocked_s.get(), tags)
478+
self.time_to_first_batch_s.set(stats.iter_time_to_first_batch_s.get(), tags)
472479
self.iter_user_s.set(stats.iter_user_s.get(), tags)
473480
self.iter_initialize_s.set(stats.iter_initialize_s.get(), tags)
474481

@@ -948,6 +955,7 @@ def __init__(
948955
self.iter_format_batch_s: Timer = Timer()
949956
self.iter_collate_batch_s: Timer = Timer()
950957
self.iter_finalize_batch_s: Timer = Timer()
958+
self.iter_time_to_first_batch_s: Timer = Timer()
951959
self.iter_total_blocked_s: Timer = Timer()
952960
self.iter_user_s: Timer = Timer()
953961
self.iter_initialize_s: Timer = Timer()
@@ -1003,6 +1011,7 @@ def to_summary(self) -> "DatasetStatsSummary":
10031011
self.iter_format_batch_s,
10041012
self.iter_collate_batch_s,
10051013
self.iter_finalize_batch_s,
1014+
self.iter_time_to_first_batch_s,
10061015
self.iter_total_blocked_s,
10071016
self.iter_user_s,
10081017
self.iter_initialize_s,
@@ -1642,6 +1651,8 @@ class IterStatsSummary:
16421651
collate_time: Timer
16431652
# Time spent in finalize_fn, in seconds
16441653
finalize_batch_time: Timer
1654+
# Time user thread is blocked waiting for first batch
1655+
time_to_first_batch: Timer
16451656
# Total time user thread is blocked by iter_batches
16461657
block_time: Timer
16471658
# Time spent in user code, in seconds
@@ -1665,6 +1676,7 @@ def to_string(self) -> str:
16651676
out = ""
16661677
if (
16671678
self.block_time.get()
1679+
or self.time_to_first_batch.get()
16681680
or self.total_time.get()
16691681
or self.get_time.get()
16701682
or self.next_time.get()
@@ -1685,6 +1697,11 @@ def to_string(self) -> str:
16851697
" * Total time user thread is blocked by Ray Data iter_batches: "
16861698
"{}\n".format(fmt(self.block_time.get()))
16871699
)
1700+
if self.time_to_first_batch.get():
1701+
out += (
1702+
" * Total time spent waiting for the first batch after starting iteration: "
1703+
"{}\n".format(fmt(self.time_to_first_batch.get()))
1704+
)
16881705
if self.user_time.get():
16891706
out += " * Total execution time for user thread: {}\n".format(
16901707
fmt(self.user_time.get())

python/ray/data/tests/test_stats.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ def test_streaming_split_stats(ray_start_regular_shared, restore_data_context):
395395
* Total time overall: T
396396
* Total time in Ray Data iterator initialization code: T
397397
* Total time user thread is blocked by Ray Data iter_batches: T
398+
* Total time spent waiting for the first batch after starting iteration: T
398399
* Total execution time for user thread: T
399400
* Batch iteration time breakdown (summed across prefetch threads):
400401
* In ray.get(): T min, T max, T avg, T total
@@ -577,6 +578,7 @@ def test_dataset_stats_basic(
577578
f"* Total time overall: T\n"
578579
f" * Total time in Ray Data iterator initialization code: T\n"
579580
f" * Total time user thread is blocked by Ray Data iter_batches: T\n"
581+
f" * Total time spent waiting for the first batch after starting iteration: T\n"
580582
f" * Total execution time for user thread: T\n"
581583
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
582584
f" * In ray.get(): T min, T max, T avg, T total\n"
@@ -618,6 +620,7 @@ def test_block_location_nums(ray_start_regular_shared, restore_data_context):
618620
f"* Total time overall: T\n"
619621
f" * Total time in Ray Data iterator initialization code: T\n"
620622
f" * Total time user thread is blocked by Ray Data iter_batches: T\n"
623+
f" * Total time spent waiting for the first batch after starting iteration: T\n"
621624
f" * Total execution time for user thread: T\n"
622625
f"* Batch iteration time breakdown (summed across prefetch threads):\n"
623626
f" * In ray.get(): T min, T max, T avg, T total\n"
@@ -1363,6 +1366,7 @@ def test_streaming_stats_full(ray_start_regular_shared, restore_data_context):
13631366
* Total time overall: T
13641367
* Total time in Ray Data iterator initialization code: T
13651368
* Total time user thread is blocked by Ray Data iter_batches: T
1369+
* Total time spent waiting for the first batch after starting iteration: T
13661370
* Total execution time for user thread: T
13671371
* Batch iteration time breakdown (summed across prefetch threads):
13681372
* In ray.get(): T min, T max, T avg, T total

0 commit comments

Comments
 (0)