Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/ray/train/collective/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Optional, TypeVar

from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
from ray.train.v2._internal.util import requires_train_worker
from ray.util.annotations import PublicAPI

T = TypeVar("T", bound=Optional[object])
Expand All @@ -11,6 +12,7 @@


@PublicAPI(stability="alpha")
@requires_train_worker()
def broadcast_from_rank_zero(data: T) -> T:
"""Broadcast small (<1kb) data from the rank 0 worker to all other workers.

Expand Down Expand Up @@ -53,6 +55,7 @@ def train_func():


@PublicAPI(stability="alpha")
@requires_train_worker()
def barrier() -> None:
"""Create a barrier across all workers.

Expand Down
10 changes: 9 additions & 1 deletion python/ray/train/v2/_internal/execution/train_fn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,18 @@ def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:


def get_train_fn_utils() -> TrainFnUtils:
    """Return the Ray Train function utilities.

    Returns:
        The TrainFnUtils instance for the current worker.

    Raises:
        RuntimeError: If the Ray Train function utilities are not initialized.
    """
    # NOTE: `global` is unnecessary here since `_train_fn_utils` is only read,
    # never assigned; the lock guards against reading a partially-set value
    # while another thread initializes it.
    with _train_fn_utils_lock:
        if _train_fn_utils is None:
            raise RuntimeError("Ray Train function utilities not initialized.")
        return _train_fn_utils


Expand Down
39 changes: 39 additions & 0 deletions python/ray/train/v2/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,42 @@ def _in_ray_train_worker() -> bool:
return True
except RuntimeError:
return False


def requires_train_worker(raise_in_tune_session: bool = False) -> Callable:
    """Build a decorator that restricts an API to Ray Train workers.

    The returned decorator verifies at call time that the caller is a Ray
    Train worker spawned by Ray Train (and therefore has access to the
    training function utilities) before delegating to the wrapped function.

    Args:
        raise_in_tune_session: Whether to raise a specific error message if the caller
            is in a Tune session. If True, will raise a DeprecationWarning.

    Returns:
        A decorator that performs this check, which raises an error if the caller
        is not a Ray Train worker.
    """

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def guarded(*args, **kwargs):
            # Imported lazily to avoid a circular dependency on ray.tune.
            from ray.tune.trainable.trainable_fn_utils import _in_tune_session

            if raise_in_tune_session and _in_tune_session():
                raise DeprecationWarning(
                    f"`ray.train.{fn.__name__}` is deprecated when running in a function "
                    "passed to Ray Tune. Please use the equivalent `ray.tune` API instead. "
                    "See this issue for more context: "
                    "https://github.com/ray-project/ray/issues/49454"
                )

            if not _in_ray_train_worker():
                raise RuntimeError(
                    f"`{fn.__name__}` cannot be used outside of a Ray Train training function. "
                    "You are calling this API from the driver or another non-training process. "
                    "These utilities are only available within a function launched by `trainer.fit()`."
                )

            return fn(*args, **kwargs)

        return guarded

    return decorator
16 changes: 6 additions & 10 deletions python/ray/train/v2/api/train_fn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
from ray.train.v2._internal.util import requires_train_worker
from ray.train.v2.api.context import TrainContext
from ray.train.v2.api.report_config import CheckpointUploadMode
from ray.util.annotations import PublicAPI
Expand All @@ -13,6 +14,7 @@


@PublicAPI(stability="stable")
@requires_train_worker(raise_in_tune_session=True)
def report(
metrics: Dict[str, Any],
checkpoint: Optional["Checkpoint"] = None,
Expand Down Expand Up @@ -101,16 +103,6 @@ def train_func(config):
validate_config: Configuration passed to the validate_fn. Can contain info
like the validation dataset.
"""
from ray.tune.trainable.trainable_fn_utils import _in_tune_session

if _in_tune_session():
raise DeprecationWarning(
"`ray.train.report` is deprecated when running in a function "
"passed to Ray Tune. Please use `ray.tune.report` instead. "
"See this issue for more context: "
"https://github.com/ray-project/ray/issues/49454"
)

if delete_local_checkpoint_after_upload is None:
delete_local_checkpoint_after_upload = (
checkpoint_upload_mode._default_delete_local_checkpoint_after_upload()
Expand All @@ -133,6 +125,7 @@ def train_func(config):


@PublicAPI(stability="stable")
@requires_train_worker(raise_in_tune_session=True)
def get_context() -> TrainContext:
"""Get or create a singleton training context.

Expand All @@ -154,6 +147,7 @@ def get_context() -> TrainContext:


@PublicAPI(stability="stable")
@requires_train_worker(raise_in_tune_session=True)
def get_checkpoint() -> Optional["Checkpoint"]:
"""Access the latest reported checkpoint to resume from if one exists.

Expand Down Expand Up @@ -211,6 +205,7 @@ def train_func(config):


@PublicAPI(stability="alpha")
@requires_train_worker()
def get_all_reported_checkpoints() -> List["ReportedCheckpoint"]:
"""Get all the reported checkpoints so far.

Expand Down Expand Up @@ -256,6 +251,7 @@ def train_func(config):


@PublicAPI(stability="stable")
@requires_train_worker()
def get_dataset_shard(dataset_name: Optional[str] = None) -> Optional["DataIterator"]:
"""Returns the :class:`ray.data.DataIterator` shard for this worker.

Expand Down
109 changes: 109 additions & 0 deletions python/ray/train/v2/torch/train_loop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_devices as get_devices_distributed,
)
from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
from ray.train.v2._internal.util import requires_train_worker
from ray.util.annotations import Deprecated, PublicAPI

logger = logging.getLogger(__name__)
Expand All @@ -38,11 +39,119 @@
)


@PublicAPI(stability="stable")
@requires_train_worker()
def get_device() -> torch.device:
    """Gets the correct torch device configured for the current worker.

    Returns the torch device for the current worker. If more than 1 GPU is
    requested per worker, returns the device with the lowest device index.

    .. note::

        If you requested multiple GPUs per worker, and want to get
        the full list of torch devices, please use
        :meth:`~ray.train.torch.get_devices`.

    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
    superset of the `ray.get_gpu_ids()`.

    Returns:
        The torch device assigned to the current worker.

    Examples:

        Example: Launched 2 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            # "cuda:0" is the logical CUDA device: index 0 within
            # CUDA_VISIBLE_DEVICES ("2,3"), i.e. physical GPU 2.
            get_device() == torch.device("cuda:0")

        Example: Launched 4 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")

        Example: Launched 2 workers on the current node, each with 2 GPUs

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2,3]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")


        You can move a model to device by:

        .. testcode::
            :skipif: True

            model.to(ray.train.torch.get_device())

        Instead of manually checking the device type:

        .. testcode::
            :skipif: True

            model.to("cuda" if torch.cuda.is_available() else "cpu")
    """
    # The full per-worker device list is sorted by index; the first entry
    # is the lowest-indexed device assigned to this worker.
    return get_devices()[0]


@PublicAPI(stability="beta")
@requires_train_worker()
def get_devices() -> List[torch.device]:
"""Gets the correct torch device list configured for the current worker.

Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
superset of the `ray.get_gpu_ids()`.

Returns:
The list of torch devices assigned to the current worker.

Examples:

Example: Launched 2 workers on the current node, each with 1 GPU

.. testcode::
:skipif: True

os.environ["CUDA_VISIBLE_DEVICES"] == "2,3"
ray.get_gpu_ids() == [2]
torch.cuda.is_available() == True
get_devices() == [torch.device("cuda:0")]

Example: Launched 4 workers on the current node, each with 1 GPU

.. testcode::
:skipif: True

os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
ray.get_gpu_ids() == [2]
torch.cuda.is_available() == True
get_devices() == [torch.device("cuda:2")]

Example: Launched 2 workers on the current node, each with 2 GPUs

.. testcode::
:skipif: True

os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
ray.get_gpu_ids() == [2,3]
torch.cuda.is_available() == True
get_devices() == [torch.device("cuda:2"), torch.device("cuda:3")]
"""
if get_train_fn_utils().is_distributed():
return get_devices_distributed()
else:
Expand Down
6 changes: 3 additions & 3 deletions python/ray/tune/tests/test_api_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ def test_trainable_fn_utils(tmp_path, monkeypatch, v2_enabled):
)

def tune_fn(config):
with asserting_context(match="ray.tune.get_checkpoint"):
with asserting_context(match="get_checkpoint"):
ray.train.get_checkpoint()

with warnings.catch_warnings():
ray.tune.get_checkpoint()

with asserting_context(match="ray.tune.get_context"):
with asserting_context(match="get_context"):
ray.train.get_context()

with warnings.catch_warnings():
ray.tune.get_context()

with asserting_context(match="ray.tune.report"):
with asserting_context(match="report"):
ray.train.report({"a": 1})

with warnings.catch_warnings():
Expand Down