Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
309eb3a
first commit
xinyuangui2 Aug 11, 2025
0750c05
only support single process for now
xinyuangui2 Aug 11, 2025
4701f1d
rename some classes
xinyuangui2 Aug 13, 2025
69ea86c
rename some classes
xinyuangui2 Aug 13, 2025
b91d65e
fix unittest and update experiment name
xinyuangui2 Aug 13, 2025
6ea2acf
merge master
xinyuangui2 Aug 14, 2025
ef4871b
move to distributedtrainer
xinyuangui2 Aug 15, 2025
0cad925
update config
xinyuangui2 Aug 15, 2025
c4de5bb
fix some namings
xinyuangui2 Aug 15, 2025
879ebbe
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 18, 2025
09e245d
add unittests for trainers
xinyuangui2 Aug 18, 2025
4ee1dda
fix v2 import for xgboost config
xinyuangui2 Aug 18, 2025
0fe6cc2
remove unused changes
xinyuangui2 Aug 18, 2025
a4217c8
clean
xinyuangui2 Aug 19, 2025
e2a888d
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 20, 2025
e0bb04a
move local tests to one single file
xinyuangui2 Aug 20, 2025
2772212
add build file for test_local_mode
xinyuangui2 Aug 20, 2025
1dfbdb8
fix config
xinyuangui2 Aug 20, 2025
ea7991e
fix
xinyuangui2 Aug 20, 2025
3e919f5
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 20, 2025
046cbe3
remove local_mode_controller as parameter
xinyuangui2 Aug 21, 2025
28f46c9
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 21, 2025
849e533
fix field
xinyuangui2 Aug 21, 2025
0e139be
remove unused changes
xinyuangui2 Aug 21, 2025
1052f3b
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 21, 2025
329db7a
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 24, 2025
bdf8391
merge master
xinyuangui2 Aug 24, 2025
6cb513e
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 25, 2025
cd6da46
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 26, 2025
b87a5c7
resolve comments
xinyuangui2 Aug 26, 2025
29a1ce6
add one more local model log
xinyuangui2 Aug 26, 2025
56f7557
resolve comments
xinyuangui2 Aug 26, 2025
416913e
refactor the xgboostConfig to avoid circular import
xinyuangui2 Aug 26, 2025
9881ee4
Revert "refactor the xgboostConfig to avoid circular import"
xinyuangui2 Aug 27, 2025
48a161a
exclude xgboosttrainer from local mode for now
xinyuangui2 Aug 27, 2025
09ea48f
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 27, 2025
4203f72
resolve comments
xinyuangui2 Aug 28, 2025
c62ecb6
resolve comments
xinyuangui2 Aug 28, 2025
5aac3d0
remove unneeded parameters
xinyuangui2 Aug 28, 2025
3551146
add xgboost into local mode
xinyuangui2 Aug 28, 2025
62c5f46
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Sep 2, 2025
f5fcaee
resolve comments
xinyuangui2 Sep 2, 2025
34390e6
Update .gitignore
xinyuangui2 Sep 2, 2025
067777a
Apply suggestions from code review
matthewdeng Sep 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/ray/train/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
accelerate,
backward,
enable_reproducibility,
get_device,
get_devices,
prepare_data_loader,
prepare_model,
prepare_optimizer,
Expand Down
102 changes: 102 additions & 0 deletions python/ray/train/v2/_internal/execution/local_mode_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import logging
from typing import Any, Callable, Dict, Optional

from ray.data import DataIterator
from ray.train import Checkpoint, Result
from ray.train.trainer import GenDataset
from ray.train.v2._internal.execution.train_fn_utils import (
TrainFnUtils,
get_train_fn_utils,
set_train_fn_utils,
)
from ray.train.v2._internal.util import date_str
from ray.train.v2.api.context import (
LocalModeTrainContext,
TrainContext as ExternalTrainContext,
)

logger = logging.getLogger(__name__)


class LocalModeTrainFnUtils(TrainFnUtils):
    """TrainFnUtils implementation for local (single-process) mode.

    The training function runs in the same process as the trainer, so
    reported metrics and checkpoints are buffered in memory instead of
    being sent to a remote controller.
    """

    def __init__(
        self,
        experiment_name: str,
        local_world_size: int,
        local_rank: int,
        dataset_shards: Optional[Dict[str, DataIterator]] = None,
    ):
        """
        Args:
            experiment_name: Name of this local training run.
            local_world_size: Number of local workers (1 in local mode).
            local_rank: Rank of this worker (0 in local mode).
            dataset_shards: Mapping of dataset name to its iterator, if any.
        """
        # Context object returned to the user's training function.
        self._context = LocalModeTrainContext(
            experiment_name=experiment_name,
            local_world_size=local_world_size,
            local_rank=local_rank,
        )
        self._dataset_shards = dataset_shards
        # Most recent values passed to `report()`; read back by
        # BackendForLocalMode when assembling the final Result.
        self._last_metrics: Optional[Dict[str, Any]] = None
        self._last_checkpoint: Optional[Checkpoint] = None

    def report(
        self,
        metrics: Dict[str, Any],
        checkpoint: Optional[Checkpoint] = None,
        checkpoint_dir_name: Optional[str] = None,
    ) -> None:
        """Record the latest metrics/checkpoint in memory.

        ``checkpoint_dir_name`` is accepted for interface compatibility
        with the distributed implementation but has no effect here, since
        local mode does not persist checkpoints to a storage path.
        """
        self._last_metrics = metrics
        self._last_checkpoint = checkpoint

    def get_checkpoint(self) -> Optional[Checkpoint]:
        """Return the checkpoint from the most recent ``report`` call, if any."""
        return self._last_checkpoint

    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
        """Return the dataset shard registered under ``dataset_name``.

        Raises:
            KeyError: If no shard with that name was provided.
        """
        # Raise explicitly instead of using `assert`, which is stripped
        # when Python runs with the -O flag.
        if self._dataset_shards is None or dataset_name not in self._dataset_shards:
            raise KeyError(f"Dataset shard {dataset_name} not found.")
        return self._dataset_shards[dataset_name]

    def get_context(self) -> ExternalTrainContext:
        """Return the local-mode train context."""
        return self._context

    def is_running_in_distributed_mode(self) -> bool:
        return False

    def _get_last_metrics(self) -> Optional[Dict[str, Any]]:
        """Return the last metrics reported by the training function.

        This function should only be called by BackendForLocalMode.
        """
        return self._last_metrics


class BackendForLocalMode:
    """Single-process backend that wires up local-mode train fn utils.

    Constructing an instance installs a :class:`LocalModeTrainFnUtils`
    as the process-wide train fn utils; ``fit`` then runs the user's
    training function inline and collects what it reported.
    """

    def __init__(self, datasets: Optional[Dict[str, GenDataset]] = None):
        # Resolve dataset factories eagerly so the training function
        # always receives concrete datasets.
        # NOTE(review): these resolved datasets are handed to
        # LocalModeTrainFnUtils as `dataset_shards` — presumably they are
        # iterable like DataIterator in local mode; confirm with callers.
        resolved = None
        if datasets is not None:
            resolved = {}
            for name, dataset in datasets.items():
                resolved[name] = dataset() if callable(dataset) else dataset

        # Local mode always runs exactly one worker in-process.
        self.local_world_size = 1
        self.local_rank = 0

        set_train_fn_utils(
            LocalModeTrainFnUtils(
                experiment_name=self._get_experiment_name(),
                local_world_size=self.local_world_size,
                local_rank=self.local_rank,
                dataset_shards=resolved,
            )
        )

    def _get_experiment_name(self) -> str:
        """Build a timestamped name for this local training run."""
        return f"local_training-{date_str()}"

    def fit(self, train_func: Callable[[], None]) -> Result:
        """Run ``train_func`` in-process and package its output as a Result."""
        train_func()
        utils = get_train_fn_utils()
        assert isinstance(utils, LocalModeTrainFnUtils)
        return Result(
            metrics=utils._get_last_metrics(),
            checkpoint=utils.get_checkpoint(),
            path=None,
            error=None,
        )
64 changes: 53 additions & 11 deletions python/ray/train/v2/_internal/execution/train_fn_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
import threading
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from ray.data import DataIterator
from ray.train import Checkpoint
from ray.train.v2._internal.execution.context import (
get_train_context as get_internal_train_context,
)
from ray.train.v2.api.context import TrainContext as ExternalTrainContext
from ray.train.v2.api.context import (
DistributedTrainContext,
TrainContext as ExternalTrainContext,
)


class TrainFnUtils:
class TrainFnUtils(ABC):
"""Utility class providing an abstraction layer between user-facing APIs
and :class:`~ray.train.v2._internal.execution.context.TrainContext`.
and :class:`~ray.train.v2.api.context.TrainContext`.

It should be set before the users' training function is called, like training workers initialization.
It should be set before the users' training function is called.
For distributed mode, it is set during training workers initialization.
For local mode, it is set during the initialization of :class:`~ray.train.v2.api.data_parallel_trainer.DataParallelTrainer`.
This class can be patched if new user-facing API behavior is wanted.
"""

@abstractmethod
def report(
self,
metrics: Dict[str, Any],
Expand All @@ -33,20 +40,20 @@ def report(
be stored in the default storage path. If set, make sure
this value is unique for each iteration.
"""
return get_internal_train_context().report(
metrics, checkpoint, checkpoint_dir_name
)
pass

def get_checkpoint(self):
@abstractmethod
def get_checkpoint(self) -> Optional[Checkpoint]:
"""Get the latest checkpoint to resume training from.

Returns:
The latest checkpoint if available, None otherwise.
"""
return get_internal_train_context().get_checkpoint()
pass

@abstractmethod
def get_dataset_shard(self, dataset_name: str) -> DataIterator:
"""Get the dataset shard for this worker.
"""Get the dataset shard for this training process.

This method is used by the public API function :func:`ray.train.get_dataset_shard`.
Users should typically call ``ray.train.get_dataset_shard()`` instead of calling this method directly.
Expand All @@ -57,6 +64,38 @@ def get_dataset_shard(self, dataset_name: str) -> DataIterator:
Returns:
The DataIterator shard for this worker.
"""
pass

@abstractmethod
def get_context(self) -> ExternalTrainContext:
"""Get the TrainContext for this training process.
Different implementations of TrainFnUtils will return different TrainContext types.

Returns:
The train context for this training process.
"""
pass

@abstractmethod
def is_running_in_distributed_mode(self) -> bool:
pass


class DistributedTrainFnUtils(TrainFnUtils):
def report(
self,
metrics: Dict[str, Any],
checkpoint: Optional[Checkpoint] = None,
checkpoint_dir_name: Optional[str] = None,
) -> None:
return get_internal_train_context().report(
metrics, checkpoint, checkpoint_dir_name
)

def get_checkpoint(self):
return get_internal_train_context().get_checkpoint()

def get_dataset_shard(self, dataset_name: str) -> DataIterator:
from ray.train.v2._internal.callbacks.datasets import DatasetShardMetadata

dataset_info = DatasetShardMetadata(
Expand All @@ -66,7 +105,10 @@ def get_dataset_shard(self, dataset_name: str) -> DataIterator:
return get_internal_train_context().get_dataset_shard(dataset_info)

def get_context(self) -> ExternalTrainContext:
return ExternalTrainContext()
return DistributedTrainContext()

def is_running_in_distributed_mode(self) -> bool:
return True


_train_fn_utils: Optional[TrainFnUtils] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from ray.train.v2._internal.execution.storage import StorageContext
from ray.train.v2._internal.execution.train_fn_utils import (
TrainFnUtils,
DistributedTrainFnUtils,
set_train_fn_utils,
)
from ray.train.v2._internal.execution.worker_group.poll import WorkerStatus
Expand Down Expand Up @@ -222,7 +222,7 @@ def init_train_context(
set_train_context(context)

# user facing train fn utils
set_train_fn_utils(TrainFnUtils())
set_train_fn_utils(DistributedTrainFnUtils())

for callback in self._callbacks:
callback.after_init_train_context()
12 changes: 11 additions & 1 deletion python/ray/train/v2/api/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
Expand Down Expand Up @@ -33,7 +34,9 @@ class ScalingConfig(ScalingConfigV1):
num_workers: The number of workers (Ray actors) to launch.
Each worker will reserve 1 CPU by default. The number of CPUs
reserved by each worker can be overridden with the
``resources_per_worker`` argument.
``resources_per_worker`` argument. If the number of workers is 0,
the training function will run in local mode, meaning the training
function runs in the same process.
use_gpu: If True, training will be done on GPUs (1 per worker).
Defaults to False. The number of GPUs reserved by each
worker can be overridden with the ``resources_per_worker``
Expand Down Expand Up @@ -119,6 +122,13 @@ def __post_init__(self):
"`use_tpu=True` and `num_workers` > 1."
)

if self.num_workers == 0:
warnings.warn(
"Running in local mode. The training function will run in the same process. "
"If you are using it and running into issues please file a report at "
"https://github.com/ray-project/ray/issues."
)

super().__post_init__()

@property
Expand Down
Loading