Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
104cb2b
CheckpointManager and Worker both count checkpoints
TimothySeah Jul 12, 2025
7410f37
Merge remote-tracking branch 'upstream/master' into tseah/get-checkpo…
TimothySeah Aug 1, 2025
edf70f6
Do not assert str equality
TimothySeah Aug 3, 2025
6158667
Finish implementation + add 1 unit test (not done yet)
TimothySeah Aug 4, 2025
2f29f68
add test_data_parallel_trainer test case
TimothySeah Aug 4, 2025
26c5535
rename TrainingResult to ReportedCheckpoint + add publicapi docstring
TimothySeah Aug 5, 2025
08b318c
fix v1 import
TimothySeah Aug 5, 2025
e2617f0
move notify to async method
TimothySeah Aug 8, 2025
2132dd6
Address some comments e.g. do not save report count in state, mock ac…
TimothySeah Aug 9, 2025
0e4f5d0
address more comments
TimothySeah Aug 10, 2025
3ef882b
Merge remote-tracking branch 'upstream/master' into tseah/get-checkpo…
TimothySeah Aug 18, 2025
66fd5a8
remove unnecessary asyncio and make controller actor fully required
TimothySeah Aug 18, 2025
ffe4397
WorkerGroup passes current actor to workers + mocking to test that
TimothySeah Aug 19, 2025
5f6af2c
document corner case
TimothySeah Aug 19, 2025
6a2e824
[train][doc] Document get_all_reported_checkpoints and ReportedCheckp…
TimothySeah Aug 20, 2025
fa71eba
Merge pull request #1 from TimothySeah/tseah/get-all-reported-checkpo…
TimothySeah Aug 20, 2025
1990be3
address pr comments and fix ci failure
TimothySeah Aug 27, 2025
7767384
add comment to num_report_calls as suggested
TimothySeah Aug 27, 2025
2cd9412
remove unnecessary line
TimothySeah Aug 27, 2025
56a0537
Merge remote-tracking branch 'upstream/master' into tseah/get-checkpo…
TimothySeah Aug 27, 2025
e3a0b02
try adding publicapi annotation
TimothySeah Aug 28, 2025
c8dfdc5
address pr comments
TimothySeah Aug 28, 2025
642ebcc
try different import
TimothySeah Aug 28, 2025
f8ded64
remove doc changes; will add in future pr
TimothySeah Aug 30, 2025
0f0eccb
Merge remote-tracking branch 'upstream/master' into tseah/get-checkpo…
TimothySeah Aug 30, 2025
1bb190c
fix unit tests
TimothySeah Aug 30, 2025
f9b75f0
always import ray.train._checkpoint.Checkpoint
TimothySeah Sep 1, 2025
4b79eaf
try adding docs again
TimothySeah Sep 1, 2025
e705cf1
Revert "always import ray.train._checkpoint.Checkpoint"
TimothySeah Sep 1, 2025
6f1c122
Use TYPE_CHECKING on Checkpoint, ReportedCheckpoint, and some related…
TimothySeah Sep 1, 2025
17f9503
remove outdated pydoclint errors
TimothySeah Sep 1, 2025
8d06d4c
Revert "try adding docs again"
TimothySeah Sep 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Any, Dict, List, Optional

import ray
from ray.air.config import CheckpointConfig
from ray.train._checkpoint import Checkpoint
from ray.train._internal.checkpoint_manager import (
Expand Down Expand Up @@ -42,6 +43,7 @@ class _CheckpointManagerState(BaseModel):
version: int = 0
checkpoint_results: List[_TrainingResultState]
latest_checkpoint_result: Optional[_TrainingResultState]
num_reported_checkpoints: int


def _get_training_result_from_state(
Expand Down Expand Up @@ -81,6 +83,7 @@ def __init__(
):
self._storage_context = storage_context
self._checkpoint_config = checkpoint_config
self._num_reported_checkpoints = 0
super().__init__(checkpoint_config)
# If the snapshot is found, the checkpoint manager will restore its state.
self._maybe_load_state_from_storage()
Expand All @@ -96,6 +99,9 @@ def register_checkpoint(self, checkpoint_result: _TrainingResult):
Args:
checkpoint: Tracked checkpoint object to add to bookkeeping.
"""
self._num_reported_checkpoints += 1
# TODO: might be nice for CheckpointManager to manage ValidatedCheckpoint
# instead of _TrainingResult but that is a large refactor.
self._latest_checkpoint_result = checkpoint_result

if self._checkpoint_config.checkpoint_score_attribute is not None:
Expand Down Expand Up @@ -162,6 +168,7 @@ def _save_state(self) -> str:
manager_snapshot = _CheckpointManagerState(
checkpoint_results=checkpoint_results,
latest_checkpoint_result=latest_checkpoint_result,
num_reported_checkpoints=self._num_reported_checkpoints,
)
return manager_snapshot.model_dump_json()

Expand Down Expand Up @@ -190,6 +197,8 @@ def _load_state(self, json_state: str):
else None
)

self._num_reported_checkpoints = manager_snapshot.num_reported_checkpoints

def _maybe_load_state_from_storage(self):
"""Load the checkpoint manager state from storage.
If no snapshot is found, start with a clean state.
Expand Down Expand Up @@ -284,4 +293,12 @@ def before_init_train_context(self, workers: List[Worker]) -> Dict[str, List[Any
if self.latest_checkpoint_result
else None
)
return {"checkpoint": [latest_checkpoint] * len(workers)}
train_context_args = {
"checkpoint": [latest_checkpoint] * len(workers),
"num_reported_checkpoints": [self._num_reported_checkpoints] * len(workers),
}
if ray.get_runtime_context().get_actor_id() is not None:
train_context_args["controller_actor"] = [
ray.get_runtime_context().current_actor
] * len(workers)
return train_context_args
4 changes: 4 additions & 0 deletions python/ray/train/v2/_internal/execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import ray
from ray.actor import ActorHandle
from ray.data import DataIterator, Dataset
from ray.train import BackendConfig, Checkpoint, DataConfig
from ray.train._internal import session
Expand Down Expand Up @@ -94,6 +95,8 @@ class TrainContext:
storage_context: StorageContext
dataset_shards: Dict[str, DataIterator]
checkpoint: Optional[Checkpoint] = None
num_reported_checkpoints: int = 0
controller_actor: Optional[ActorHandle] = None

@_copy_doc(session.get_experiment_name)
def get_experiment_name(self) -> str:
Expand Down Expand Up @@ -267,6 +270,7 @@ def report(
# TODO (hpguo): Add a metrics to track the blocking time waiting for the
# training result to be consumed by the controller.
self.get_result_queue().put(training_result)
self.num_reported_checkpoints += 1


# The global variable holding the current TrainContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def init_train_context(
worker_callbacks: List[Union[WorkerCallback, TrainContextCallback]],
dataset_shards: Dict[str, DataIterator] = None,
checkpoint: Optional[Checkpoint] = None,
num_reported_checkpoints: int = 0,
controller_actor: Optional[ActorHandle] = None,
):
self._callbacks = [c for c in worker_callbacks if isinstance(c, WorkerCallback)]
context_callbacks_to_propagate = [
Expand All @@ -209,6 +211,8 @@ def init_train_context(
storage_context=storage_context,
dataset_shards=dataset_shards or {},
checkpoint=checkpoint,
num_reported_checkpoints=num_reported_checkpoints,
controller_actor=controller_actor,
)
# Configure the train and root logger for the worker processes.
if ray_constants.env_bool(
Expand Down
17 changes: 17 additions & 0 deletions python/ray/train/v2/api/reported_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import Any, Dict

from ray.train import Checkpoint


@dataclass
class ValidatedCheckpoint:
    """A user-reported checkpoint and its associated metrics.

    NOTE(review): the commit history and the module name
    (``reported_result.py``) refer to this concept as ``ReportedCheckpoint``,
    but the class is named ``ValidatedCheckpoint`` — confirm the name is
    intentional before release, since renaming later is a breaking change.

    Attributes:
        checkpoint: The checkpoint reported by the user.
        metrics: The metrics associated with that checkpoint.
    """

    # The checkpoint object the user passed to ray.train.report().
    checkpoint: Checkpoint
    # Free-form metrics dict reported alongside the checkpoint.
    metrics: Dict[str, Any]
69 changes: 65 additions & 4 deletions python/ray/train/v2/tests/test_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from ray.train.v2._internal.execution.storage import StorageContext
from ray.train.v2._internal.execution.worker_group import Worker
from ray.train.v2._internal.execution.worker_group.worker import ActorMetadata


@pytest.fixture(autouse=True, scope="module")
Expand Down Expand Up @@ -113,8 +114,9 @@ def test_save_load_state_equivalence(
)

# Register the training results into checkpoint manager
for tr in training_results:
for i, tr in enumerate(training_results):
checkpoint_manager.register_checkpoint(tr)
assert checkpoint_manager._num_reported_checkpoints == i + 1
loaded_checkpoint_manager = CheckpointManager(
storage_context=storage_context,
checkpoint_config=checkpoint_config,
Expand Down Expand Up @@ -158,17 +160,76 @@ def test_before_init_train_context(tmp_path):

# Assert without a checkpoint.
assert checkpoint_manager.before_init_train_context(workers) == {
"checkpoint": [None] * 4
"checkpoint": [None] * 4,
"num_reported_checkpoints": [0] * 4,
}

# Assert with a checkpoint
latest_checkpoint_result = _create_dummy_training_results(1, storage_context)[0]
checkpoint_manager._latest_checkpoint_result = latest_checkpoint_result
checkpoint_manager.register_checkpoint(latest_checkpoint_result)
assert checkpoint_manager.before_init_train_context(workers) == {
"checkpoint": [latest_checkpoint_result.checkpoint] * 4
"checkpoint": [latest_checkpoint_result.checkpoint] * 4,
"num_reported_checkpoints": [1] * 4,
}


def test_before_init_train_context_from_actor(tmp_path):
    """When CheckpointManager runs as a Ray actor, before_init_train_context
    must additionally broadcast a handle to the controller actor (itself) to
    every worker, alongside the latest checkpoint and the report count."""
    # Create the checkpoint manager as a Ray actor rather than a local
    # object, so ray.get_runtime_context().get_actor_id() is non-None inside
    # before_init_train_context and the controller_actor key is populated.
    storage_context = StorageContext(
        storage_path=tmp_path,
        experiment_dir_name="my_experiment_name",
    )
    checkpoint_manager_actor_cls = ray.remote(CheckpointManager)
    checkpoint_manager = checkpoint_manager_actor_cls.remote(
        storage_context=storage_context,
        checkpoint_config=CheckpointConfig(),
    )

    # Create workers backed by trivial actors; only the Worker wrapper
    # structure matters here, not what the actor can do.
    @ray.remote
    class DummyActor:
        pass

    workers = [
        Worker(
            actor=DummyActor.remote(),
            metadata=ActorMetadata(
                hostname="hostname",
                node_id="node_id",
                node_ip="node_ip",
                pid="pid",
                accelerator_ids={},
            ),
            resources={},
        )
        for _ in range(4)
    ]

    # Without a registered checkpoint: checkpoint is None and the report
    # count is 0 for all 4 workers, but the controller actor handle is
    # still broadcast to every worker.
    assert ray.get(checkpoint_manager.before_init_train_context.remote(workers)) == {
        "checkpoint": [None] * 4,
        "num_reported_checkpoints": [0] * 4,
        "controller_actor": [checkpoint_manager] * 4,
    }

    # With one registered checkpoint: report count becomes 1 and the latest
    # checkpoint is broadcast to every worker.
    latest_checkpoint_result = _create_dummy_training_results(1, storage_context)[0]
    ray.get(checkpoint_manager.register_checkpoint.remote(latest_checkpoint_result))
    train_context_args = ray.get(
        checkpoint_manager.before_init_train_context.remote(workers)
    )
    assert train_context_args.keys() == {
        "checkpoint",
        "num_reported_checkpoints",
        "controller_actor",
    }
    assert train_context_args["controller_actor"] == [checkpoint_manager] * 4
    assert train_context_args["num_reported_checkpoints"] == [1] * 4
    # Checkpoints are compared via str() — presumably because Checkpoint
    # objects round-tripped through the actor don't compare equal by
    # identity/equality; confirm against Checkpoint's __eq__ semantics.
    assert [str(checkpoint) for checkpoint in train_context_args["checkpoint"]] == [
        str(latest_checkpoint_result.checkpoint)
    ] * 4


if __name__ == "__main__":
import sys

Expand Down