Commit ea27046
[train][checkpoint] Add ray.train.get_all_reported_checkpoints method (#54555)
# Summary

This PR adds a `ray.train.get_all_reported_checkpoints` method that lets users retrieve all the checkpoints they have reported from within their training function. It differs from [Result](https://docs.ray.io/en/latest/train/user-guides/results.html) in two ways:

* It is called from the training function on the training worker instead of from the driver.
* It can be called while training is still in progress.

# Implementation Notes

The main idea is to pair a worker-side counter with a controller-side counter:

* Train worker: `ray.train.report` increments a `num_reported_checkpoints` counter and puts the training result into its queue.
* Train controller: polls the training results from all workers, registers the checkpoint, increments `num_reported_checkpoints`, and then creates an asyncio task to notify an asyncio `Condition`. This works because asyncio Ray actors should always have an event loop.
* Train worker: `get_all_reported_checkpoints` uses an `asyncio.Condition` to wait until the worker-side `num_reported_checkpoints` counter matches its controller-side counterpart before returning the checkpoints. This ensures that all pending reports have finished. The worker has access to the controller actor through `init_train_context`.

`get_checkpoint` is unaffected because it uses the local checkpoint; we can consider changing it to use the "centrally committed" checkpoint in the future.

# Testing

I ran the [Ray Train PyTorch example](https://docs.ray.io/en/latest/train/getting-started-pytorch.html) and called `ray.train.get_all_reported_checkpoints` at the end of each epoch. The results are as expected; here are a few examples:

```
epoch 4: [TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-34-52.538994), metrics={'loss': 0.24510294198989868, 'epoch': 0}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-07.511694), metrics={'loss': 0.23799467086791992, 'epoch': 1}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-24.355974), metrics={'loss': 0.39628422260284424, 'epoch': 2}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-40.273211), metrics={'loss': 0.15193207561969757, 'epoch': 3}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-56.178119), metrics={'loss': 0.17416314780712128, 'epoch': 4})]
```

```
epoch 9: [TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-34-52.538994), metrics={'loss': 0.24510294198989868, 'epoch': 0}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-07.511694), metrics={'loss': 0.23799467086791992, 'epoch': 1}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-24.355974), metrics={'loss': 0.39628422260284424, 'epoch': 2}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-40.273211), metrics={'loss': 0.15193207561969757, 'epoch': 3}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-35-56.178119), metrics={'loss': 0.17416314780712128, 'epoch': 4}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-36-12.547310), metrics={'loss': 0.2924661934375763, 'epoch': 5}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-36-28.538090), metrics={'loss': 0.18640762567520142, 'epoch': 6}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-36-44.583228), metrics={'loss': 0.12567029893398285, 'epoch': 7}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-37-00.540405), metrics={'loss': 0.1620682030916214, 'epoch': 8}), TrainingResult(checkpoint=Checkpoint(filesystem=local, path=/mnt/cluster_storage/my_run_name/checkpoint_2025-08-04_17-37-17.129973), metrics={'loss': 0.07022886723279953, 'epoch': 9})]
```

I also modified all the Ray Train v2 unit tests that call `ray.train.report`:

* `test_persistence` also verifies that `get_all_reported_checkpoints` works on resumption.
* `test_data_parallel_trainer` verifies that `get_all_reported_checkpoints` stalls until all workers report.

I also verified that `get_all_reported_checkpoints` produced similar output when called from Tune + Train. I tried to test that `get_all_reported_checkpoints` finishes even with a graceful abort, but I was unable to create such a scenario, since `get_all_reported_checkpoints` returns very quickly and each `report` forms a barrier.

---------

Signed-off-by: Timothy Seah <tseah@anyscale.com>
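The counter handshake described in the implementation notes can be sketched with plain asyncio. This is a toy model, not Ray code: `ToyController` and `train_fn` are hypothetical stand-ins for the controller actor and the training function, and the real implementation routes these calls through Ray actor RPCs and a result queue.

```python
import asyncio


class ToyController:
    """Hypothetical stand-in for the Train controller actor: it counts
    processed report calls and wakes waiters whenever the count changes."""

    def __init__(self):
        self.num_reported_checkpoints = 0
        self.checkpoints = []
        self.condition = asyncio.Condition()

    async def register_report(self, metrics):
        # Controller side: register the checkpoint, bump the counter,
        # then notify anyone waiting on the condition.
        self.checkpoints.append(metrics)
        self.num_reported_checkpoints += 1
        async with self.condition:
            self.condition.notify_all()

    async def get_all_reported_checkpoints(self, expected):
        # Block until the controller has processed `expected` reports.
        async with self.condition:
            await self.condition.wait_for(
                lambda: self.num_reported_checkpoints >= expected
            )
        return list(self.checkpoints)


async def train_fn(controller):
    # Worker side: keep a local counter of report calls.
    num_report_calls = 0
    for epoch in range(3):
        await controller.register_report({"epoch": epoch})
        num_report_calls += 1
    # Returns only once the controller-side counter has caught up.
    return await controller.get_all_reported_checkpoints(num_report_calls)


async def main():
    # Construct the controller inside the running event loop.
    return await train_fn(ToyController())


print(asyncio.run(main()))
```

Because the worker passes its own call count, a matching (or larger) controller-side count guarantees every pending report has been registered before the list is returned.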
1 parent 686cd6e · commit ea27046 · 17 files changed: +252 −31 lines changed

ci/lint/pydoclint-baseline.txt

Lines changed: 0 additions & 8 deletions

```diff
@@ -2006,14 +2006,6 @@ python/ray/train/v2/_internal/callbacks/accelerators.py
 python/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py
 DOC103: Method `CheckpointManager.register_checkpoint`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [checkpoint_result: _TrainingResult]. Arguments in the docstring but not in the function signature: [checkpoint: ].
 --------------------
-python/ray/train/v2/_internal/execution/context.py
-DOC101: Method `TrainContext._save_checkpoint`: Docstring contains fewer arguments than in function signature.
-DOC103: Method `TrainContext._save_checkpoint`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [checkpoint: Optional[Checkpoint], checkpoint_dir_name: str, metrics: Dict[str, Any]].
---------------------
-python/ray/train/v2/_internal/execution/controller/controller.py
-DOC101: Method `TrainController._start_worker_group`: Docstring contains fewer arguments than in function signature.
-DOC103: Method `TrainController._start_worker_group`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [num_workers: int, resources_per_worker: dict].
---------------------
 python/ray/train/v2/_internal/execution/storage.py
 DOC101: Method `_ExcludingLocalFilesystem.__init__`: Docstring contains fewer arguments than in function signature.
 DOC103: Method `_ExcludingLocalFilesystem.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ].
```

python/ray/train/__init__.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -34,8 +34,10 @@
     RunConfig,
     ScalingConfig,
 )
+from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint  # noqa: F811
 from ray.train.v2.api.result import Result  # noqa: F811
 from ray.train.v2.api.train_fn_utils import (  # noqa: F811
+    get_all_reported_checkpoints,
     get_checkpoint,
     get_context,
     get_dataset_shard,
@@ -76,9 +78,14 @@
 SyncConfig.__module__ = "ray.train"
 TrainingIterator.__module__ = "ray.train"
 
+# TODO: consider implementing these in v1 and raising ImportError instead.
 if is_v2_enabled():
     __all__.append("UserCallback")
     UserCallback.__module__ = "ray.train"
+    __all__.append("get_all_reported_checkpoints")
+    get_all_reported_checkpoints.__module__ = "ray.train"
+    __all__.append("ReportedCheckpoint")
+    ReportedCheckpoint.__module__ = "ray.train"
 
 
 # DO NOT ADD ANYTHING AFTER THIS LINE.
```

python/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py

Lines changed: 40 additions & 1 deletion

```diff
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 from typing import Any, Dict, List, Optional
 
@@ -16,6 +17,7 @@
 from ray.train.v2._internal.execution.context import StorageContext
 from ray.train.v2._internal.execution.storage import _delete_fs_path, _exists_at_fs_path
 from ray.train.v2._internal.execution.worker_group import Worker
+from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint
 
 try:
     from pydantic import BaseModel
@@ -81,6 +83,12 @@ def __init__(
     ):
         self._storage_context = storage_context
         self._checkpoint_config = checkpoint_config
+
+        # This tracks the number of report calls that have been processed
+        # for the current worker group.
+        self._num_report_calls = 0
+
+        self._condition = asyncio.Condition()
         super().__init__(checkpoint_config)
         # If the snapshot is found, the checkpoint manager will restore its state.
         self._maybe_load_state_from_storage()
@@ -139,6 +147,14 @@ def register_checkpoint(self, checkpoint_result: _TrainingResult):
             logger.debug("Deleting checkpoint: ", checkpoint)
             _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)
 
+        self._num_report_calls += 1
+
+        async def async_notify():
+            async with self._condition:
+                self._condition.notify_all()
+
+        asyncio.create_task(async_notify())
+
     # --------------------------
     # CheckpointManager state
     # --------------------------
@@ -267,6 +283,7 @@ def after_report(
         self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint]
     ):
         if not checkpoint:
+            self._num_report_calls += 1
            return
 
        rank_0_metrics = metrics[0]
@@ -279,9 +296,31 @@
    # --------------------------
 
    def before_init_train_context(self, workers: List[Worker]) -> Dict[str, List[Any]]:
+        self._num_report_calls = 0
        latest_checkpoint = (
            self.latest_checkpoint_result.checkpoint
            if self.latest_checkpoint_result
            else None
        )
-        return {"checkpoint": [latest_checkpoint] * len(workers)}
+        train_context_args = {
+            "checkpoint": [latest_checkpoint] * len(workers),
+        }
+        return train_context_args
+
+    async def get_all_reported_checkpoints(
+        self, expected_num_report_calls: int
+    ) -> List[ReportedCheckpoint]:
+        """Once expected_num_checkpoints are reported, return the ReportedCheckpoints."""
+        async with self._condition:
+            await self._condition.wait_for(
+                lambda: self._num_report_calls == expected_num_report_calls
+            )
+        # TODO: might be nice for CheckpointManager to manage ReportedCheckpoint
+        # instead of _TrainingResult but that is a large refactor.
+        return [
+            ReportedCheckpoint(
+                checkpoint=tr.checkpoint,
+                metrics=tr.metrics,
+            )
+            for tr in self._checkpoint_results
+        ]
```
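The `register_checkpoint` hunk above bumps the counter from synchronous code and then schedules an `async_notify` task, because `asyncio.Condition.notify_all()` must be called while holding the condition's lock inside a coroutine. A minimal standalone sketch of that notify/wait handshake (`ToyCheckpointManager` and `demo` are illustrative names, not the actual Ray Train internals):

```python
import asyncio


class ToyCheckpointManager:
    """Hypothetical sketch of the counter + condition handshake."""

    def __init__(self):
        self._num_report_calls = 0
        self._condition = asyncio.Condition()

    def register_checkpoint(self):
        # Synchronous method running on the actor's event loop: bump the
        # counter, then notify waiters from an asyncio task, since
        # notify_all() requires holding the condition's async lock.
        self._num_report_calls += 1

        async def async_notify():
            async with self._condition:
                self._condition.notify_all()

        asyncio.create_task(async_notify())

    async def get_all_reported_checkpoints(self, expected_num_report_calls):
        # Wait until the expected number of reports has been processed.
        async with self._condition:
            await self._condition.wait_for(
                lambda: self._num_report_calls >= expected_num_report_calls
            )
        return self._num_report_calls


async def demo():
    manager = ToyCheckpointManager()
    # Start a waiter that expects two reports before returning.
    waiter = asyncio.create_task(manager.get_all_reported_checkpoints(2))
    await asyncio.sleep(0)  # let the waiter block on the condition
    manager.register_checkpoint()
    manager.register_checkpoint()
    return await waiter


print(asyncio.run(demo()))
```

Scheduling the notification as a task keeps `register_checkpoint` synchronous while still waking every coroutine blocked in `wait_for`.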

python/ray/train/v2/_internal/execution/context.py

Lines changed: 24 additions & 6 deletions

```diff
@@ -7,8 +7,8 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import ray
+from ray.actor import ActorHandle
 from ray.data import DataIterator, Dataset
-from ray.train import BackendConfig, Checkpoint, DataConfig
 from ray.train._internal import session
 from ray.train._internal.session import _TrainingResult
 from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
@@ -17,12 +17,14 @@
 from ray.train.v2.api.config import RunConfig, ScalingConfig
 
 if TYPE_CHECKING:
+    from ray.train import BackendConfig, Checkpoint, DataConfig
     from ray.train.v2._internal.data_integration.interfaces import (
         DatasetShardMetadata,
         DatasetShardProvider,
     )
     from ray.train.v2._internal.execution.callback import TrainContextCallback
     from ray.train.v2._internal.execution.worker_group.thread_runner import ThreadRunner
+    from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint
 
 
 logger = logging.getLogger(__file__)
@@ -45,13 +47,13 @@ class TrainRunContext:
     scaling_config: ScalingConfig
 
     # The configuration for the training backend (e.g., PyTorch, XGBoost).
-    backend_config: BackendConfig
+    backend_config: "BackendConfig"
 
     # The datasets used in the current training run.
     datasets: Dict[str, Dataset]
 
     # The configuration for dataset ingestion and sharding.
-    dataset_config: DataConfig
+    dataset_config: "DataConfig"
 
     def get_run_config(self) -> RunConfig:
         """Returns the run config of the current training run."""
@@ -96,8 +98,11 @@ class TrainContext:
     distributed_context: DistributedContext
     execution_context: ExecutionContext
     storage_context: StorageContext
+    controller_actor: ActorHandle
+
     dataset_shard_provider: "DatasetShardProvider"
-    checkpoint: Optional[Checkpoint] = None
+    checkpoint: Optional["Checkpoint"] = None
+    num_report_calls: int = 0
 
     @_copy_doc(session.get_experiment_name)
     def get_experiment_name(self) -> str:
@@ -137,6 +142,13 @@ def get_synchronization_actor(self):
     def get_checkpoint(self):
         return self.checkpoint
 
+    def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:
+        return ray.get(
+            self.controller_actor.get_all_reported_checkpoints.remote(
+                self.num_report_calls
+            )
+        )
+
     def get_dataset_shard(self, dataset_info: "DatasetShardMetadata") -> DataIterator:
         """Returns the :class:`ray.data.DataIterator` shard for this worker.
 
@@ -189,10 +201,15 @@ def _save_checkpoint(
         self,
         checkpoint_dir_name: str,
         metrics: Dict[str, Any],
-        checkpoint: Optional[Checkpoint] = None,
+        checkpoint: Optional["Checkpoint"] = None,
     ) -> _TrainingResult:
         """Save the checkpoint to remote storage.
 
+        Args:
+            checkpoint_dir_name: The checkpoint dir to persist to.
+            metrics: The metrics to report.
+            checkpoint: The checkpoint to report.
+
         Returns:
             The training result object containing the persisted checkpoint.
         """
@@ -212,7 +229,7 @@ def _save_checkpoint(
     def report(
         self,
         metrics: Dict[str, Any],
-        checkpoint: Optional[Checkpoint] = None,
+        checkpoint: Optional["Checkpoint"] = None,
         checkpoint_dir_name: Optional[str] = None,
     ) -> None:
         """
@@ -265,6 +282,7 @@ def report(
         # TODO (hpguo): Add a metrics to track the blocking time waiting for the
         # training result to be consumed by the controller.
         self.get_result_queue().put(training_result)
+        self.num_report_calls += 1
 
 
 # The global variable holding the current TrainContext
```

python/ray/train/v2/_internal/execution/controller/controller.py

Lines changed: 16 additions & 2 deletions

```diff
@@ -3,7 +3,7 @@
 import os
 import uuid
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Union
+from typing import TYPE_CHECKING, Callable, List, Optional, Union
 
 import pandas as pd
 
@@ -67,6 +67,10 @@
 )
 from ray.train.v2.api.result import Result
 
+if TYPE_CHECKING:
+    from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint
+
+
 logger = logging.getLogger(__name__)
 
 
@@ -275,6 +279,10 @@ def _start_worker_group(
     ) -> Optional[ControllerError]:
         """Start the worker group and launch the train function.
 
+        Args:
+            num_workers: The number of workers to start.
+            resources_per_worker: The resources per worker to start.
+
         Returns:
             None if the worker group was successfully started,
             ControllerError if the worker group failed to start.
@@ -537,7 +545,6 @@ def get_result(self) -> Result:
             raise ValueError(
                 f"Cannot get result when controller is in state {controller_state}"
             )
-
         return self._build_result()
 
     def get_training_failed_error(self) -> Optional[TrainingFailedError]:
@@ -553,3 +560,10 @@ def get_training_failed_error(self) -> Optional[TrainingFailedError]:
             return controller_state.training_failed_error
 
         return None
+
+    async def get_all_reported_checkpoints(
+        self, expected_num_report_calls: int
+    ) -> List["ReportedCheckpoint"]:
+        return await self._checkpoint_manager.get_all_reported_checkpoints(
+            expected_num_report_calls
+        )
```

python/ray/train/v2/_internal/execution/train_fn_utils.py

Lines changed: 15 additions & 3 deletions

```diff
@@ -1,14 +1,17 @@
 import threading
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from ray.data import DataIterator
-from ray.train import Checkpoint
 from ray.train.v2._internal.execution import collective_impl
 from ray.train.v2._internal.execution.context import (
     get_train_context as get_internal_train_context,
 )
 from ray.train.v2.api.context import TrainContext as ExternalTrainContext
 
+if TYPE_CHECKING:
+    from ray.train import Checkpoint
+    from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint
+
 
 class TrainFnUtils:
     """Utility class providing an abstraction layer between user-facing APIs
@@ -21,7 +24,7 @@ class TrainFnUtils:
     def report(
         self,
         metrics: Dict[str, Any],
-        checkpoint: Optional[Checkpoint] = None,
+        checkpoint: Optional["Checkpoint"] = None,
         checkpoint_dir_name: Optional[str] = None,
     ) -> None:
         """Upload checkpoint to remote storage and put a training result on the result queue.
@@ -46,6 +49,15 @@ def get_checkpoint(self):
         """
         return get_internal_train_context().get_checkpoint()
 
+    def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:
+        """Get all the checkpoints reported by the workers.
+
+        Returns:
+            A list of ReportedCheckpoint objects that represent the checkpoints and
+            corresponding metrics reported by the workers.
+        """
+        return get_internal_train_context().get_all_reported_checkpoints()
+
     def get_dataset_shard(self, dataset_name: str) -> DataIterator:
         """Get the dataset shard for this worker.
 
```
python/ray/train/v2/_internal/execution/worker_group/worker.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -194,6 +194,7 @@ def init_train_context(
     synchronization_actor: SynchronizationActor,
     storage_context: StorageContext,
     worker_callbacks: List[Union[WorkerCallback, TrainContextCallback]],
+    controller_actor: ActorHandle,
     dataset_shard_provider: Optional["DatasetShardProvider"] = None,
     checkpoint: Optional[Checkpoint] = None,
 ):
@@ -213,6 +214,7 @@
             train_context_callbacks=context_callbacks_to_propagate,
         ),
         storage_context=storage_context,
+        controller_actor=controller_actor,
         checkpoint=checkpoint,
         dataset_shard_provider=dataset_shard_provider,
     )
```

python/ray/train/v2/_internal/execution/worker_group/worker_group.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -437,6 +437,7 @@ def _init_train_context_on_workers(
         synchronization_actor=sync_actor,
         storage_context=self._storage_context,
         worker_callbacks=self._worker_callbacks_to_propagate,
+        controller_actor=ray.get_runtime_context().current_actor,
         **{
             arg: arg_values[i] for arg, arg_values in train_context_args.items()
         },
```
python/ray/train/v2/api/reported_checkpoint.py

Lines changed: 21 additions & 0 deletions

```diff
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Dict
+
+from ray.util.annotations import PublicAPI
+
+if TYPE_CHECKING:
+    from ray.train import Checkpoint
+
+
+@dataclass
+@PublicAPI(stability="alpha")
+class ReportedCheckpoint:
+    """A user-reported checkpoint and its associated metrics.
+
+    Attributes:
+        checkpoint: The checkpoint reported by the user.
+        metrics: The metrics associated with that checkpoint.
+    """
+
+    checkpoint: "Checkpoint"
+    metrics: Dict[str, Any]
```
