Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
309eb3a
first commit
xinyuangui2 Aug 11, 2025
0750c05
only support single process for now
xinyuangui2 Aug 11, 2025
4701f1d
rename some classes
xinyuangui2 Aug 13, 2025
69ea86c
rename some classes
xinyuangui2 Aug 13, 2025
b91d65e
fix unittest and update experiment name
xinyuangui2 Aug 13, 2025
6ea2acf
merge master
xinyuangui2 Aug 14, 2025
ef4871b
move to distributedtrainer
xinyuangui2 Aug 15, 2025
0cad925
update config
xinyuangui2 Aug 15, 2025
c4de5bb
fix some namings
xinyuangui2 Aug 15, 2025
879ebbe
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 18, 2025
09e245d
add unittests for trainers
xinyuangui2 Aug 18, 2025
4ee1dda
fix v2 import for xgboost config
xinyuangui2 Aug 18, 2025
0fe6cc2
remove unused changes
xinyuangui2 Aug 18, 2025
a4217c8
clean
xinyuangui2 Aug 19, 2025
e2a888d
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 20, 2025
e0bb04a
move local tests to one single file
xinyuangui2 Aug 20, 2025
2772212
add build file for test_local_mode
xinyuangui2 Aug 20, 2025
1dfbdb8
fix config
xinyuangui2 Aug 20, 2025
ea7991e
fix
xinyuangui2 Aug 20, 2025
3e919f5
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 20, 2025
046cbe3
remove local_mode_controller as parameter
xinyuangui2 Aug 21, 2025
28f46c9
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 21, 2025
849e533
fix field
xinyuangui2 Aug 21, 2025
0e139be
remove unused changes
xinyuangui2 Aug 21, 2025
1052f3b
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 21, 2025
329db7a
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 24, 2025
bdf8391
merge master
xinyuangui2 Aug 24, 2025
6cb513e
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 25, 2025
cd6da46
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 26, 2025
b87a5c7
resolve comments
xinyuangui2 Aug 26, 2025
29a1ce6
add one more local model log
xinyuangui2 Aug 26, 2025
56f7557
resolve comments
xinyuangui2 Aug 26, 2025
416913e
refactor the xgboostConfig to avoid circular import
xinyuangui2 Aug 26, 2025
9881ee4
Revert "refactor the xgboostConfig to avoid circular import"
xinyuangui2 Aug 27, 2025
48a161a
exclude xgboosttrainer from local mode for now
xinyuangui2 Aug 27, 2025
09ea48f
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Aug 27, 2025
4203f72
resolve comments
xinyuangui2 Aug 28, 2025
c62ecb6
resolve comments
xinyuangui2 Aug 28, 2025
5aac3d0
remove unneeded parameters
xinyuangui2 Aug 28, 2025
3551146
add xgboost into local mode
xinyuangui2 Aug 28, 2025
62c5f46
Merge branch 'master' into use-fnutils-in-trainer
xinyuangui2 Sep 2, 2025
f5fcaee
resolve comments
xinyuangui2 Sep 2, 2025
34390e6
Update .gitignore
xinyuangui2 Sep 2, 2025
067777a
Apply suggestions from code review
matthewdeng Sep 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/lint/pydoclint-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1976,7 +1976,7 @@ python/ray/train/torch/torch_predictor.py
--------------------
python/ray/train/torch/torch_trainer.py
DOC104: Method `TorchTrainer.__init__`: Arguments are the same in the docstring and the function signature, but are in a different order.
DOC105: Method `TorchTrainer.__init__`: Argument names match, but type hints in these args do not match: train_loop_per_worker, train_loop_config, torch_config, scaling_config, run_config, datasets, dataset_config, metadata, resume_from_checkpoint
DOC105: Method `TorchTrainer.__init__`: Argument names match, but type hints in these args do not match: train_loop_per_worker, train_loop_config, torch_config, scaling_config, run_config, datasets, dataset_config, metadata, resume_from_checkpoint, running_without_ray_train_controller
--------------------
python/ray/train/torch/train_loop_utils.py
DOC201: Function `get_device` does not have a return section in docstring
Expand Down
2 changes: 2 additions & 0 deletions python/ray/train/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
accelerate,
backward,
enable_reproducibility,
get_device,
get_devices,
prepare_data_loader,
prepare_model,
prepare_optimizer,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import logging
from typing import Any, Callable, Dict, Optional

from ray.data import DataIterator
from ray.train import Checkpoint, Result
from ray.train.trainer import GenDataset
from ray.train.v2._internal.execution.train_fn_utils import (
TrainFnUtils,
get_train_fn_utils,
set_train_fn_utils,
)
from ray.train.v2._internal.util import date_str
from ray.train.v2.api.context import (
TrainContext as ExternalTrainContext,
TrainContextWithoutRayTrainController,
)

logger = logging.getLogger(__name__)


class TorchWithoutRayTrainControllerFnUtils(TrainFnUtils):
    """TrainFnUtils for jobs launched without ray train controller.

    This is more for testing purposes, and some functionality is missing:
    checkpoints are neither persisted nor restored, and only the most
    recently reported metrics are retained.
    """

    def __init__(
        self,
        experiment_name: str,
        local_world_size: int,
        local_rank: int,
        dataset_shards: Optional[Dict[str, DataIterator]] = None,
    ):
        """Build the controller-less train-fn utils.

        Args:
            experiment_name: Name recorded on the train context.
            local_world_size: Number of workers on this node (1 in
                single-process mode).
            local_rank: Rank of this worker within the node.
            dataset_shards: Optional mapping from dataset name to the
                DataIterator served by ``get_dataset_shard``.
        """
        self._context = TrainContextWithoutRayTrainController(
            experiment_name=experiment_name,
            local_world_size=local_world_size,
            local_rank=local_rank,
        )
        self._dataset_shards = dataset_shards
        # Most recently reported metrics; read back via _get_last_metrics().
        self._last_metrics = None

    def report(
        self,
        metrics: Dict[str, Any],
        checkpoint: Optional[Checkpoint] = None,
        checkpoint_dir_name: Optional[str] = None,
    ) -> None:
        """Record the reported metrics.

        ``checkpoint`` and ``checkpoint_dir_name`` are accepted for interface
        compatibility but intentionally ignored: nothing is persisted in this
        controller-less test mode.
        """
        self._last_metrics = metrics

    def get_checkpoint(self) -> Optional[Checkpoint]:
        """Always ``None``: checkpoint restoration needs a controller."""
        return None

    def get_dataset_shard(self, dataset_name: str) -> DataIterator:
        """Return the dataset shard registered under ``dataset_name``.

        Raises:
            KeyError: If no dataset shards were provided or the name is
                unknown.
        """
        # Raise explicitly instead of using `assert`, which is stripped when
        # Python runs with -O and would then fail with an uninformative error.
        if self._dataset_shards is None or dataset_name not in self._dataset_shards:
            raise KeyError(f"Dataset shard {dataset_name} not found.")
        return self._dataset_shards[dataset_name]

    def get_context(self) -> ExternalTrainContext:
        """Return the TrainContext built for this process."""
        return self._context

    def is_running_with_ray_train_controller(self) -> bool:
        # This implementation exists precisely for the no-controller path.
        return False

    def _get_last_metrics(self) -> Optional[Dict[str, Any]]:
        """Return the last metrics reported by the training function.

        This function should only be called by
        ``TorchBackendWithoutRayTrainController``.
        """
        return self._last_metrics


class TorchBackendWithoutRayTrainController:
    """Minimal single-process backend that runs a training function without a
    Ray Train controller. Intended for testing; the local world size is fixed
    at 1."""

    def __init__(self, datasets: Optional[Dict[str, GenDataset]] = None):
        """Install the process-wide TrainFnUtils for a single local worker.

        Args:
            datasets: Optional mapping from dataset name to a dataset or a
                zero-argument callable producing one; callables are resolved
                eagerly here.
        """
        resolved = None
        if datasets is not None:
            resolved = {}
            for name, ds in datasets.items():
                # GenDataset entries may be factories; call them once up front.
                resolved[name] = ds() if callable(ds) else ds

        # Single process: exactly one local worker, always rank 0.
        self.local_world_size = 1
        self.local_rank = 0

        fn_utils = TorchWithoutRayTrainControllerFnUtils(
            experiment_name=self._get_experiment_name(),
            local_world_size=self.local_world_size,
            local_rank=self.local_rank,
            dataset_shards=resolved,
        )
        set_train_fn_utils(fn_utils)

    def _get_experiment_name(self) -> str:
        """Return a date-stamped experiment name for this run."""
        return f"train_without_ray_train_controller-{date_str()}"

    def fit(self, train_func: Callable[[], None]) -> Result:
        """Run ``train_func`` and wrap its last reported metrics in a Result.

        Args:
            train_func: Zero-argument training function to execute inline.

        Returns:
            A Result carrying the last reported metrics; checkpoint, path and
            error are always ``None`` in this mode.
        """
        train_func()
        utils = get_train_fn_utils()
        # __init__ installed our own TrainFnUtils; anything else is a bug.
        assert isinstance(utils, TorchWithoutRayTrainControllerFnUtils)
        return Result(
            metrics=utils._get_last_metrics(),
            checkpoint=None,
            path=None,
            error=None,
        )
62 changes: 52 additions & 10 deletions python/ray/train/v2/_internal/execution/train_fn_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
import threading
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from ray.data import DataIterator
from ray.train import Checkpoint
from ray.train.v2._internal.execution.context import (
get_train_context as get_internal_train_context,
)
from ray.train.v2.api.context import TrainContext as ExternalTrainContext
from ray.train.v2.api.context import (
TrainContext as ExternalTrainContext,
TrainContextWithRayTrainController,
)


class TrainFnUtils:
class TrainFnUtils(ABC):
"""Utility class providing an abstraction layer between user-facing APIs
and :class:`~ray.train.v2._internal.execution.context.TrainContext`.
and :class:`~ray.train.v2.api.context.TrainContext`.

It should be set before the users' training function is called, like training workers initialization.
This class can be patched if new user-facing API behavior is wanted.
"""

@abstractmethod
def report(
self,
metrics: Dict[str, Any],
Expand All @@ -33,20 +38,20 @@ def report(
be stored in the default storage path. If set, make sure
this value is unique for each iteration.
"""
return get_internal_train_context().report(
metrics, checkpoint, checkpoint_dir_name
)
pass

def get_checkpoint(self):
@abstractmethod
def get_checkpoint(self) -> Optional[Checkpoint]:
"""Get the latest checkpoint to resume training from.

Returns:
The latest checkpoint if available, None otherwise.
"""
return get_internal_train_context().get_checkpoint()
pass

@abstractmethod
def get_dataset_shard(self, dataset_name: str) -> DataIterator:
"""Get the dataset shard for this worker.
"""Get the dataset shard for this training process.

This method is used by the public API function :func:`ray.train.get_dataset_shard`.
Users should typically call ``ray.train.get_dataset_shard()`` instead of calling this method directly.
Expand All @@ -57,10 +62,47 @@ def get_dataset_shard(self, dataset_name: str) -> DataIterator:
Returns:
The DataIterator shard for this worker.
"""
pass

@abstractmethod
def get_context(self) -> ExternalTrainContext:
"""Get the TrainContext for this training process.
Different implementations of TrainFnUtils will return different TrainContext subclasses.

Returns:
The train context for this training process.
"""
pass

@abstractmethod
def is_running_with_ray_train_controller(self) -> bool:
pass


class TrainFnUtilsWithRayTrainController(TrainFnUtils):
"""TrainFnUtils for jobs launched with ray train controller."""

def report(
self,
metrics: Dict[str, Any],
checkpoint: Optional[Checkpoint] = None,
checkpoint_dir_name: Optional[str] = None,
) -> None:
return get_internal_train_context().report(
metrics, checkpoint, checkpoint_dir_name
)

def get_checkpoint(self):
return get_internal_train_context().get_checkpoint()

def get_dataset_shard(self, dataset_name: str) -> DataIterator:
return get_internal_train_context().get_dataset_shard(dataset_name)

def get_context(self) -> ExternalTrainContext:
return ExternalTrainContext()
return TrainContextWithRayTrainController()

def is_running_with_ray_train_controller(self) -> bool:
return True


_train_fn_utils: Optional[TrainFnUtils] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from ray.train.v2._internal.execution.storage import StorageContext
from ray.train.v2._internal.execution.train_fn_utils import (
TrainFnUtils,
TrainFnUtilsWithRayTrainController,
set_train_fn_utils,
)
from ray.train.v2._internal.execution.worker_group.poll import WorkerStatus
Expand Down Expand Up @@ -224,7 +224,7 @@ def init_train_context(
set_train_context(context)

# user facing train fn utils
set_train_fn_utils(TrainFnUtils())
set_train_fn_utils(TrainFnUtilsWithRayTrainController())

for callback in self._callbacks:
callback.after_init_train_context()
Loading