Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/ray/train/collective/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, TypeVar

from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
from ray.train.v2._internal.util import requires_train_worker
from ray.util.annotations import PublicAPI

T = TypeVar("T", bound=Optional[object])
Expand All @@ -11,6 +12,7 @@


@PublicAPI(stability="alpha")
@requires_train_worker()
def broadcast_from_rank_zero(data: T) -> T:
"""Broadcast small (<1kb) data from the rank 0 worker to all other workers.

Expand Down Expand Up @@ -53,6 +55,7 @@ def train_func():


@PublicAPI(stability="alpha")
@requires_train_worker()
def barrier() -> None:
"""Create a barrier across all workers.

Expand Down
10 changes: 9 additions & 1 deletion python/ray/train/v2/_internal/execution/train_fn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,18 @@ def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:


def get_train_fn_utils() -> TrainFnUtils:
    """Return the Ray Train function utilities.

    Returns:
        The TrainFnUtils instance for the current worker.

    Raises:
        RuntimeError: If the Ray Train function utilities are not initialized.
    """
    global _train_fn_utils
    # Read under the module lock so initialization performed by another
    # thread is observed consistently.
    with _train_fn_utils_lock:
        if _train_fn_utils is None:
            raise RuntimeError("Ray Train function utilities not initialized.")
        return _train_fn_utils


Expand Down
39 changes: 39 additions & 0 deletions python/ray/train/v2/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,42 @@ def _in_ray_train_worker() -> bool:
return True
except RuntimeError:
return False


def requires_train_worker(raise_in_tune_session: bool = False) -> Callable:
    """Check that the caller is a Ray Train worker spawned by Ray Train,
    with access to training function utilities.

    Args:
        raise_in_tune_session: Whether to raise a specific error message if the
            caller is in a Tune session. If True, will raise a
            DeprecationWarning.

    Returns:
        A decorator that performs this check, which raises an error if the
        caller is not a Ray Train worker.
    """

    def _decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def _checked(*args, **kwargs):
            # Imported lazily to avoid a circular import with ray.tune.
            from ray.tune.trainable.trainable_fn_utils import _in_tune_session

            if raise_in_tune_session and _in_tune_session():
                raise DeprecationWarning(
                    f"`ray.train.{fn.__name__}` is deprecated when running in a function "
                    "passed to Ray Tune. Please use the equivalent `ray.tune` API instead. "
                    "See this issue for more context: "
                    "https://github.com/ray-project/ray/issues/49454"
                )

            # Happy path: only a genuine Ray Train worker may proceed.
            if _in_ray_train_worker():
                return fn(*args, **kwargs)

            raise RuntimeError(
                f"`{fn.__name__}` cannot be used outside of a Ray Train training function. "
                "You are calling this API from the driver or another non-training process. "
                "These utilities are only available within a function launched by `trainer.fit()`."
            )

        return _checked

    return _decorator
36 changes: 6 additions & 30 deletions python/ray/train/v2/api/train_fn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
from ray.train.v2._internal.util import requires_train_worker
from ray.train.v2.api.context import TrainContext
from ray.train.v2.api.report_config import CheckpointUploadMode
from ray.util.annotations import PublicAPI
Expand All @@ -13,6 +14,7 @@


@PublicAPI(stability="stable")
@requires_train_worker(raise_in_tune_session=True)
def report(
metrics: Dict[str, Any],
checkpoint: Optional["Checkpoint"] = None,
Expand Down Expand Up @@ -101,16 +103,6 @@ def train_func(config):
validate_config: Configuration passed to the validate_fn. Can contain info
like the validation dataset.
"""
from ray.tune.trainable.trainable_fn_utils import _in_tune_session

if _in_tune_session():
raise DeprecationWarning(
"`ray.train.report` is deprecated when running in a function "
"passed to Ray Tune. Please use `ray.tune.report` instead. "
"See this issue for more context: "
"https://github.com/ray-project/ray/issues/49454"
)

if delete_local_checkpoint_after_upload is None:
delete_local_checkpoint_after_upload = (
checkpoint_upload_mode._default_delete_local_checkpoint_after_upload()
Expand All @@ -133,27 +125,19 @@ def train_func(config):


@PublicAPI(stability="stable")
@requires_train_worker(raise_in_tune_session=True)
def get_context() -> TrainContext:
    """Get or create a singleton training context.

    The context is only available within a function passed to Ray Train.

    See the :class:`~ray.train.TrainContext` API reference to see available methods.
    """
    # The worker check and the Tune-session deprecation error are handled by
    # the `requires_train_worker(raise_in_tune_session=True)` decorator, so
    # no inline `_in_tune_session()` check is needed here.
    return get_train_fn_utils().get_context()


@PublicAPI(stability="stable")
@requires_train_worker(raise_in_tune_session=True)
def get_checkpoint() -> Optional["Checkpoint"]:
"""Access the latest reported checkpoint to resume from if one exists.

Expand Down Expand Up @@ -197,20 +181,11 @@ def train_func(config):
Checkpoint object if the session is currently being resumed.
Otherwise, return None.
"""
from ray.tune.trainable.trainable_fn_utils import _in_tune_session

if _in_tune_session():
raise DeprecationWarning(
"`ray.train.get_checkpoint` is deprecated when running in a function "
"passed to Ray Tune. Please use `ray.tune.get_checkpoint` instead. "
"See this issue for more context: "
"https://github.com/ray-project/ray/issues/49454"
)

return get_train_fn_utils().get_checkpoint()


@PublicAPI(stability="alpha")
@requires_train_worker()
def get_all_reported_checkpoints() -> List["ReportedCheckpoint"]:
"""Get all the reported checkpoints so far.

Expand Down Expand Up @@ -256,6 +231,7 @@ def train_func(config):


@PublicAPI(stability="stable")
@requires_train_worker()
def get_dataset_shard(dataset_name: Optional[str] = None) -> Optional["DataIterator"]:
"""Returns the :class:`ray.data.DataIterator` shard for this worker.

Expand Down
109 changes: 109 additions & 0 deletions python/ray/train/v2/torch/train_loop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_devices as get_devices_distributed,
)
from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
from ray.train.v2._internal.util import requires_train_worker
from ray.util.annotations import Deprecated, PublicAPI

logger = logging.getLogger(__name__)
Expand All @@ -38,11 +39,119 @@
)


@PublicAPI(stability="stable")
@requires_train_worker()
def get_device() -> torch.device:
    """Gets the correct torch device configured for the current worker.

    Returns the torch device for the current worker. If more than 1 GPU is
    requested per worker, returns the device with the lowest device index.

    .. note::

        If you requested multiple GPUs per worker, and want to get
        the full list of torch devices, please use
        :meth:`~ray.train.torch.get_devices`.

    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
    superset of the `ray.get_gpu_ids()`.

    Returns:
        The torch device assigned to the current worker.

    Examples:

        Example: Launched 2 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            # "cuda:0" is the logical CUDA device: index 0 within
            # CUDA_VISIBLE_DEVICES ("2,3"), which maps to physical GPU 2.
            get_device() == torch.device("cuda:0")

        Example: Launched 4 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")

        Example: Launched 2 workers on the current node, each with 2 GPUs

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2,3]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")


    You can move a model to device by:

    .. testcode::
        :skipif: True

        model.to(ray.train.torch.get_device())

    Instead of manually checking the device type:

    .. testcode::
        :skipif: True

        model.to("cuda" if torch.cuda.is_available() else "cpu")
    """
    # Lowest-indexed device from the worker's full device list.
    return get_devices()[0]


@PublicAPI(stability="beta")
@requires_train_worker()
def get_devices() -> List[torch.device]:
"""Gets the list of torch devices configured for the current worker.

Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
superset of the `ray.get_gpu_ids()`.

Returns:
The list of torch devices assigned to the current worker.

Examples:

Example: Launched 2 workers on the current node, each with 1 GPU

.. testcode::
:skipif: True

os.environ["CUDA_VISIBLE_DEVICES"] == "2,3"
ray.get_gpu_ids() == [2]
torch.cuda.is_available() == True
get_devices() == [torch.device("cuda:0")]

Example: Launched 4 workers on the current node, each with 1 GPU

.. testcode::
:skipif: True

os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
ray.get_gpu_ids() == [2]
torch.cuda.is_available() == True
get_devices() == [torch.device("cuda:2")]

Example: Launched 2 workers on the current node, each with 2 GPUs

.. testcode::
:skipif: True

os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
ray.get_gpu_ids() == [2,3]
torch.cuda.is_available() == True
get_devices() == [torch.device("cuda:2"), torch.device("cuda:3")]
"""
if get_train_fn_utils().is_distributed():
return get_devices_distributed()
else:
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/tests/test_api_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ def test_trainable_fn_utils(tmp_path, monkeypatch, v2_enabled):
)

def tune_fn(config):
with asserting_context(match="ray.tune.get_checkpoint"):
with asserting_context(match="get_checkpoint"):
ray.train.get_checkpoint()

with warnings.catch_warnings():
ray.tune.get_checkpoint()

with asserting_context(match="ray.tune.get_context"):
with asserting_context(match="get_context"):
ray.train.get_context()

with warnings.catch_warnings():
ray.tune.get_context()

with asserting_context(match="ray.tune.report"):
with asserting_context(match="report"):
ray.train.report({"a": 1})

with warnings.catch_warnings():
Expand Down