Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
104cb2b
CheckpointManager and Worker both count checkpoints
TimothySeah Jul 12, 2025
7410f37
Merge remote-tracking branch 'upstream/master' into tseah/get-checkpo…
TimothySeah Aug 1, 2025
edf70f6
Do not assert str equality
TimothySeah Aug 3, 2025
6158667
Finish implementation + add 1 unit test (not done yet)
TimothySeah Aug 4, 2025
2f29f68
add test_data_parallel_trainer test case
TimothySeah Aug 4, 2025
26c5535
rename TrainingResult to ReportedCheckpoint + add publicapi docstring
TimothySeah Aug 5, 2025
08b318c
fix v1 import
TimothySeah Aug 5, 2025
e2617f0
move notify to async method
TimothySeah Aug 8, 2025
2132dd6
Address some comments e.g. do not save report count in state, mock ac…
TimothySeah Aug 9, 2025
0e4f5d0
address more comments
TimothySeah Aug 10, 2025
3ef882b
Merge remote-tracking branch 'upstream/master' into tseah/get-checkpo…
TimothySeah Aug 18, 2025
66fd5a8
remove unnecessary asyncio and make controller actor fully required
TimothySeah Aug 18, 2025
ffe4397
WorkerGroup passes current actor to workers + mocking to test that
TimothySeah Aug 19, 2025
5f6af2c
document corner case
TimothySeah Aug 19, 2025
6a2e824
[train][doc] Document get_all_reported_checkpoints and ReportedCheckp…
TimothySeah Aug 20, 2025
fa71eba
Merge pull request #1 from TimothySeah/tseah/get-all-reported-checkpo…
TimothySeah Aug 20, 2025
1990be3
address pr comments and fix ci failure
TimothySeah Aug 27, 2025
7767384
add comment to num_report_calls as suggested
TimothySeah Aug 27, 2025
2cd9412
remove unnecessary line
TimothySeah Aug 27, 2025
56a0537
Merge remote-tracking branch 'upstream/master' into tseah/get-checkpo…
TimothySeah Aug 27, 2025
e3a0b02
try adding publicapi annotation
TimothySeah Aug 28, 2025
c8dfdc5
address pr comments
TimothySeah Aug 28, 2025
642ebcc
try different import
TimothySeah Aug 28, 2025
f8ded64
remove doc changes; will add in future pr
TimothySeah Aug 30, 2025
0f0eccb
Merge remote-tracking branch 'upstream/master' into tseah/get-checkpo…
TimothySeah Aug 30, 2025
1bb190c
fix unit tests
TimothySeah Aug 30, 2025
f9b75f0
always import ray.train._checkpoint.Checkpoint
TimothySeah Sep 1, 2025
4b79eaf
try adding docs again
TimothySeah Sep 1, 2025
e705cf1
Revert "always import ray.train._checkpoint.Checkpoint"
TimothySeah Sep 1, 2025
6f1c122
Use TYPE_CHECKING on Checkpoint, ReportedCheckpoint, and some related…
TimothySeah Sep 1, 2025
17f9503
remove outdated pydoclint errors
TimothySeah Sep 1, 2025
8d06d4c
Revert "try adding docs again"
TimothySeah Sep 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions ci/lint/pydoclint-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2006,14 +2006,6 @@ python/ray/train/v2/_internal/callbacks/accelerators.py
python/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py
DOC103: Method `CheckpointManager.register_checkpoint`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [checkpoint_result: _TrainingResult]. Arguments in the docstring but not in the function signature: [checkpoint: ].
--------------------
python/ray/train/v2/_internal/execution/context.py
DOC101: Method `TrainContext._save_checkpoint`: Docstring contains fewer arguments than in function signature.
DOC103: Method `TrainContext._save_checkpoint`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [checkpoint: Optional[Checkpoint], checkpoint_dir_name: str, metrics: Dict[str, Any]].
--------------------
python/ray/train/v2/_internal/execution/controller/controller.py
DOC101: Method `TrainController._start_worker_group`: Docstring contains fewer arguments than in function signature.
DOC103: Method `TrainController._start_worker_group`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [num_workers: int, resources_per_worker: dict].
--------------------
python/ray/train/v2/_internal/execution/storage.py
DOC101: Method `_ExcludingLocalFilesystem.__init__`: Docstring contains fewer arguments than in function signature.
DOC103: Method `_ExcludingLocalFilesystem.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: ].
Expand Down
7 changes: 7 additions & 0 deletions python/ray/train/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
RunConfig,
ScalingConfig,
)
from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint # noqa: F811
from ray.train.v2.api.result import Result # noqa: F811
from ray.train.v2.api.train_fn_utils import ( # noqa: F811
get_all_reported_checkpoints,
get_checkpoint,
get_context,
get_dataset_shard,
Expand Down Expand Up @@ -76,9 +78,14 @@
SyncConfig.__module__ = "ray.train"
TrainingIterator.__module__ = "ray.train"

# TODO: consider implementing these in v1 and raising ImportError instead.
if is_v2_enabled():
__all__.append("UserCallback")
UserCallback.__module__ = "ray.train"
__all__.append("get_all_reported_checkpoints")
get_all_reported_checkpoints.__module__ = "ray.train"
__all__.append("ReportedCheckpoint")
ReportedCheckpoint.__module__ = "ray.train"


# DO NOT ADD ANYTHING AFTER THIS LINE.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Any, Dict, List, Optional

Expand All @@ -16,6 +17,7 @@
from ray.train.v2._internal.execution.context import StorageContext
from ray.train.v2._internal.execution.storage import _delete_fs_path, _exists_at_fs_path
from ray.train.v2._internal.execution.worker_group import Worker
from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint

try:
from pydantic import BaseModel
Expand Down Expand Up @@ -81,6 +83,12 @@ def __init__(
):
self._storage_context = storage_context
self._checkpoint_config = checkpoint_config

# This tracks the number of report calls that have been processed
# for the current worker group.
self._num_report_calls = 0

self._condition = asyncio.Condition()
super().__init__(checkpoint_config)
# If the snapshot is found, the checkpoint manager will restore its state.
self._maybe_load_state_from_storage()
Expand Down Expand Up @@ -139,6 +147,14 @@ def register_checkpoint(self, checkpoint_result: _TrainingResult):
logger.debug("Deleting checkpoint: ", checkpoint)
_delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)

self._num_report_calls += 1

async def async_notify():
async with self._condition:
self._condition.notify_all()

asyncio.create_task(async_notify())

# --------------------------
# CheckpointManager state
# --------------------------
Expand Down Expand Up @@ -267,6 +283,7 @@ def after_report(
self, metrics: List[Dict[str, Any]], checkpoint: Optional[Checkpoint]
):
if not checkpoint:
self._num_report_calls += 1
return

rank_0_metrics = metrics[0]
Expand All @@ -279,9 +296,31 @@ def after_report(
# --------------------------

def before_init_train_context(self, workers: List[Worker]) -> Dict[str, List[Any]]:
self._num_report_calls = 0
latest_checkpoint = (
self.latest_checkpoint_result.checkpoint
if self.latest_checkpoint_result
else None
)
return {"checkpoint": [latest_checkpoint] * len(workers)}
train_context_args = {
"checkpoint": [latest_checkpoint] * len(workers),
}
return train_context_args

async def get_all_reported_checkpoints(
    self, expected_num_report_calls: int
) -> List[ReportedCheckpoint]:
    """Wait until the expected number of report calls have been processed,
    then return all checkpoints reported so far.

    Args:
        expected_num_report_calls: The number of ``report`` calls that must
            have been processed by this checkpoint manager before returning.

    Returns:
        The checkpoints and their associated metrics reported by the
        workers, in registration order.
    """
    async with self._condition:
        # Use >= rather than == : _num_report_calls only ever increases,
        # so if it passes expected_num_report_calls before this coroutine
        # first evaluates the predicate, an equality check would never
        # become true and this wait would hang forever.
        await self._condition.wait_for(
            lambda: self._num_report_calls >= expected_num_report_calls
        )
        # TODO: might be nice for CheckpointManager to manage
        # ReportedCheckpoint instead of _TrainingResult, but that is a
        # large refactor.
        return [
            ReportedCheckpoint(
                checkpoint=tr.checkpoint,
                metrics=tr.metrics,
            )
            for tr in self._checkpoint_results
        ]
30 changes: 24 additions & 6 deletions python/ray/train/v2/_internal/execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import ray
from ray.actor import ActorHandle
from ray.data import DataIterator, Dataset
from ray.train import BackendConfig, Checkpoint, DataConfig
from ray.train._internal import session
from ray.train._internal.session import _TrainingResult
from ray.train.v2._internal.execution.checkpoint.sync_actor import SynchronizationActor
Expand All @@ -17,12 +17,14 @@
from ray.train.v2.api.config import RunConfig, ScalingConfig

if TYPE_CHECKING:
from ray.train import BackendConfig, Checkpoint, DataConfig
from ray.train.v2._internal.data_integration.interfaces import (
DatasetShardMetadata,
DatasetShardProvider,
)
from ray.train.v2._internal.execution.callback import TrainContextCallback
from ray.train.v2._internal.execution.worker_group.thread_runner import ThreadRunner
from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint


logger = logging.getLogger(__file__)
Expand All @@ -45,13 +47,13 @@ class TrainRunContext:
scaling_config: ScalingConfig

# The configuration for the training backend (e.g., PyTorch, XGBoost).
backend_config: BackendConfig
backend_config: "BackendConfig"

# The datasets used in the current training run.
datasets: Dict[str, Dataset]

# The configuration for dataset ingestion and sharding.
dataset_config: DataConfig
dataset_config: "DataConfig"

def get_run_config(self) -> RunConfig:
"""Returns the run config of the current training run."""
Expand Down Expand Up @@ -96,8 +98,11 @@ class TrainContext:
distributed_context: DistributedContext
execution_context: ExecutionContext
storage_context: StorageContext
controller_actor: ActorHandle

dataset_shard_provider: "DatasetShardProvider"
checkpoint: Optional[Checkpoint] = None
checkpoint: Optional["Checkpoint"] = None
num_report_calls: int = 0

@_copy_doc(session.get_experiment_name)
def get_experiment_name(self) -> str:
Expand Down Expand Up @@ -137,6 +142,13 @@ def get_synchronization_actor(self):
def get_checkpoint(self):
return self.checkpoint

def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:
    """Fetch every checkpoint reported so far from the controller actor.

    Passes this worker's own report-call count so the controller can wait
    until that many report calls have been processed before answering.
    """
    checkpoints_ref = self.controller_actor.get_all_reported_checkpoints.remote(
        self.num_report_calls
    )
    return ray.get(checkpoints_ref)

def get_dataset_shard(self, dataset_info: "DatasetShardMetadata") -> DataIterator:
"""Returns the :class:`ray.data.DataIterator` shard for this worker.

Expand Down Expand Up @@ -189,10 +201,15 @@ def _save_checkpoint(
self,
checkpoint_dir_name: str,
metrics: Dict[str, Any],
checkpoint: Optional[Checkpoint] = None,
checkpoint: Optional["Checkpoint"] = None,
) -> _TrainingResult:
"""Save the checkpoint to remote storage.

Args:
checkpoint_dir_name: The checkpoint dir to persist to.
metrics: The metrics to report.
checkpoint: The checkpoint to report.

Returns:
The training result object containing the persisted checkpoint.
"""
Expand All @@ -212,7 +229,7 @@ def _save_checkpoint(
def report(
self,
metrics: Dict[str, Any],
checkpoint: Optional[Checkpoint] = None,
checkpoint: Optional["Checkpoint"] = None,
checkpoint_dir_name: Optional[str] = None,
) -> None:
"""
Expand Down Expand Up @@ -265,6 +282,7 @@ def report(
# TODO (hpguo): Add a metrics to track the blocking time waiting for the
# training result to be consumed by the controller.
self.get_result_queue().put(training_result)
self.num_report_calls += 1


# The global variable holding the current TrainContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import uuid
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import pandas as pd

Expand Down Expand Up @@ -67,6 +67,10 @@
)
from ray.train.v2.api.result import Result

if TYPE_CHECKING:
from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -275,6 +279,10 @@ def _start_worker_group(
) -> Optional[ControllerError]:
"""Start the worker group and launch the train function.

Args:
num_workers: The number of workers to start.
resources_per_worker: The resources per worker to start.

Returns:
None if the worker group was successfully started,
ControllerError if the worker group failed to start.
Expand Down Expand Up @@ -537,7 +545,6 @@ def get_result(self) -> Result:
raise ValueError(
f"Cannot get result when controller is in state {controller_state}"
)

return self._build_result()

def get_training_failed_error(self) -> Optional[TrainingFailedError]:
Expand All @@ -553,3 +560,10 @@ def get_training_failed_error(self) -> Optional[TrainingFailedError]:
return controller_state.training_failed_error

return None

async def get_all_reported_checkpoints(
    self, expected_num_report_calls: int
) -> List["ReportedCheckpoint"]:
    """Return all reported checkpoints once ``expected_num_report_calls``
    report calls have been processed.

    This is a thin pass-through to the checkpoint manager, which owns the
    report bookkeeping and the wait condition.
    """
    manager = self._checkpoint_manager
    return await manager.get_all_reported_checkpoints(expected_num_report_calls)
18 changes: 15 additions & 3 deletions python/ray/train/v2/_internal/execution/train_fn_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import threading
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from ray.data import DataIterator
from ray.train import Checkpoint
from ray.train.v2._internal.execution import collective_impl
from ray.train.v2._internal.execution.context import (
get_train_context as get_internal_train_context,
)
from ray.train.v2.api.context import TrainContext as ExternalTrainContext

if TYPE_CHECKING:
from ray.train import Checkpoint
from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint


class TrainFnUtils:
"""Utility class providing an abstraction layer between user-facing APIs
Expand All @@ -21,7 +24,7 @@ class TrainFnUtils:
def report(
self,
metrics: Dict[str, Any],
checkpoint: Optional[Checkpoint] = None,
checkpoint: Optional["Checkpoint"] = None,
checkpoint_dir_name: Optional[str] = None,
) -> None:
"""Upload checkpoint to remote storage and put a training result on the result queue.
Expand All @@ -46,6 +49,15 @@ def get_checkpoint(self):
"""
return get_internal_train_context().get_checkpoint()

def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:
    """Get all the checkpoints reported by the workers.

    Returns:
        A list of ReportedCheckpoint objects representing the checkpoints
        and the metrics that were reported alongside them.
    """
    internal_context = get_internal_train_context()
    return internal_context.get_all_reported_checkpoints()

def get_dataset_shard(self, dataset_name: str) -> DataIterator:
"""Get the dataset shard for this worker.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def init_train_context(
synchronization_actor: SynchronizationActor,
storage_context: StorageContext,
worker_callbacks: List[Union[WorkerCallback, TrainContextCallback]],
controller_actor: ActorHandle,
dataset_shard_provider: Optional["DatasetShardProvider"] = None,
checkpoint: Optional[Checkpoint] = None,
):
Expand All @@ -213,6 +214,7 @@ def init_train_context(
train_context_callbacks=context_callbacks_to_propagate,
),
storage_context=storage_context,
controller_actor=controller_actor,
checkpoint=checkpoint,
dataset_shard_provider=dataset_shard_provider,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def _init_train_context_on_workers(
synchronization_actor=sync_actor,
storage_context=self._storage_context,
worker_callbacks=self._worker_callbacks_to_propagate,
controller_actor=ray.get_runtime_context().current_actor,
**{
arg: arg_values[i] for arg, arg_values in train_context_args.items()
},
Expand Down
21 changes: 21 additions & 0 deletions python/ray/train/v2/api/reported_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict

from ray.util.annotations import PublicAPI

# Checkpoint is only needed for the type annotation; guarding the import
# avoids a circular import with ray.train at runtime.
if TYPE_CHECKING:
    from ray.train import Checkpoint


# Decorators apply bottom-up: @PublicAPI annotates the plain class first,
# then @dataclass generates __init__/__repr__/__eq__ on the annotated class.
@dataclass
@PublicAPI(stability="alpha")
class ReportedCheckpoint:
    """A user-reported checkpoint and its associated metrics.

    Attributes:
        checkpoint: The checkpoint reported by the user.
        metrics: The metrics associated with that checkpoint.
    """

    # The checkpoint object that the user passed to a report call.
    checkpoint: "Checkpoint"
    # The metrics dict reported alongside that checkpoint.
    metrics: Dict[str, Any]
Loading