Skip to content

Commit dd60663

Browse files
justinvyu and peterxcli
authored and committed
[train] Improve error message if users call training function utils outside of a Ray Train worker (ray-project#57863)
Introduce a decorator to mark functions that require running inside a worker process spawned by Ray Train. --------- Signed-off-by: Justin Yu <justinvyu@anyscale.com> Signed-off-by: peterxcli <peterxcli@gmail.com>
1 parent 35530b8 commit dd60663

File tree

6 files changed

+169
-34
lines changed

6 files changed

+169
-34
lines changed

python/ray/train/collective/collectives.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, TypeVar
33

44
from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
5+
from ray.train.v2._internal.util import requires_train_worker
56
from ray.util.annotations import PublicAPI
67

78
T = TypeVar("T", bound=Optional[object])
@@ -11,6 +12,7 @@
1112

1213

1314
@PublicAPI(stability="alpha")
15+
@requires_train_worker()
1416
def broadcast_from_rank_zero(data: T) -> T:
1517
"""Broadcast small (<1kb) data from the rank 0 worker to all other workers.
1618
@@ -53,6 +55,7 @@ def train_func():
5355

5456

5557
@PublicAPI(stability="alpha")
58+
@requires_train_worker()
5659
def barrier() -> None:
5760
"""Create a barrier across all workers.
5861

python/ray/train/v2/_internal/execution/train_fn_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,10 +258,18 @@ def get_all_reported_checkpoints(self) -> List["ReportedCheckpoint"]:
258258

259259

260260
def get_train_fn_utils() -> TrainFnUtils:
261+
"""Return the Ray Train function utilities.
262+
263+
Returns:
264+
The TrainFnUtils instance for the current worker.
265+
266+
Raises:
267+
RuntimeError: If the Ray Train function utilities are not initialized.
268+
"""
261269
global _train_fn_utils
262270
with _train_fn_utils_lock:
263271
if _train_fn_utils is None:
264-
raise RuntimeError("TrainFnUtils has not been initialized.")
272+
raise RuntimeError("Ray Train function utilities not initialized.")
265273
return _train_fn_utils
266274

267275

python/ray/train/v2/_internal/util.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,42 @@ def _in_ray_train_worker() -> bool:
249249
return True
250250
except RuntimeError:
251251
return False
252+
253+
254+
def requires_train_worker(raise_in_tune_session: bool = False) -> Callable:
255+
"""Check that the caller is a Ray Train worker spawned by Ray Train,
256+
with access to training function utilities.
257+
258+
Args:
259+
raise_in_tune_session: Whether to raise a specific error message if the caller
260+
is in a Tune session. If True, will raise a DeprecationWarning.
261+
262+
Returns:
263+
A decorator that performs this check, which raises an error if the caller
264+
is not a Ray Train worker.
265+
"""
266+
267+
def _wrap(fn: Callable) -> Callable:
268+
@functools.wraps(fn)
269+
def _wrapped_fn(*args, **kwargs):
270+
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
271+
272+
if raise_in_tune_session and _in_tune_session():
273+
raise DeprecationWarning(
274+
f"`ray.train.{fn.__name__}` is deprecated when running in a function "
275+
"passed to Ray Tune. Please use the equivalent `ray.tune` API instead. "
276+
"See this issue for more context: "
277+
"https://github.com/ray-project/ray/issues/49454"
278+
)
279+
280+
if not _in_ray_train_worker():
281+
raise RuntimeError(
282+
f"`{fn.__name__}` cannot be used outside of a Ray Train training function. "
283+
"You are calling this API from the driver or another non-training process. "
284+
"These utilities are only available within a function launched by `trainer.fit()`."
285+
)
286+
return fn(*args, **kwargs)
287+
288+
return _wrapped_fn
289+
290+
return _wrap

python/ray/train/v2/api/train_fn_utils.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ray.train.v2._internal.data_integration.interfaces import DatasetShardMetadata
44
from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
5+
from ray.train.v2._internal.util import requires_train_worker
56
from ray.train.v2.api.context import TrainContext
67
from ray.train.v2.api.report_config import CheckpointUploadMode
78
from ray.util.annotations import PublicAPI
@@ -13,6 +14,7 @@
1314

1415

1516
@PublicAPI(stability="stable")
17+
@requires_train_worker(raise_in_tune_session=True)
1618
def report(
1719
metrics: Dict[str, Any],
1820
checkpoint: Optional["Checkpoint"] = None,
@@ -101,16 +103,6 @@ def train_func(config):
101103
validate_config: Configuration passed to the validate_fn. Can contain info
102104
like the validation dataset.
103105
"""
104-
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
105-
106-
if _in_tune_session():
107-
raise DeprecationWarning(
108-
"`ray.train.report` is deprecated when running in a function "
109-
"passed to Ray Tune. Please use `ray.tune.report` instead. "
110-
"See this issue for more context: "
111-
"https://github.com/ray-project/ray/issues/49454"
112-
)
113-
114106
if delete_local_checkpoint_after_upload is None:
115107
delete_local_checkpoint_after_upload = (
116108
checkpoint_upload_mode._default_delete_local_checkpoint_after_upload()
@@ -133,27 +125,19 @@ def train_func(config):
133125

134126

135127
@PublicAPI(stability="stable")
128+
@requires_train_worker(raise_in_tune_session=True)
136129
def get_context() -> TrainContext:
137130
"""Get or create a singleton training context.
138131
139132
The context is only available within a function passed to Ray Train.
140133
141134
See the :class:`~ray.train.TrainContext` API reference to see available methods.
142135
"""
143-
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
144-
145-
if _in_tune_session():
146-
raise DeprecationWarning(
147-
"`ray.train.get_context` is deprecated when running in a function "
148-
"passed to Ray Tune. Please use `ray.tune.get_context` instead. "
149-
"See this issue for more context: "
150-
"https://github.com/ray-project/ray/issues/49454"
151-
)
152-
153136
return get_train_fn_utils().get_context()
154137

155138

156139
@PublicAPI(stability="stable")
140+
@requires_train_worker(raise_in_tune_session=True)
157141
def get_checkpoint() -> Optional["Checkpoint"]:
158142
"""Access the latest reported checkpoint to resume from if one exists.
159143
@@ -197,20 +181,11 @@ def train_func(config):
197181
Checkpoint object if the session is currently being resumed.
198182
Otherwise, return None.
199183
"""
200-
from ray.tune.trainable.trainable_fn_utils import _in_tune_session
201-
202-
if _in_tune_session():
203-
raise DeprecationWarning(
204-
"`ray.train.get_checkpoint` is deprecated when running in a function "
205-
"passed to Ray Tune. Please use `ray.tune.get_checkpoint` instead. "
206-
"See this issue for more context: "
207-
"https://github.com/ray-project/ray/issues/49454"
208-
)
209-
210184
return get_train_fn_utils().get_checkpoint()
211185

212186

213187
@PublicAPI(stability="alpha")
188+
@requires_train_worker()
214189
def get_all_reported_checkpoints() -> List["ReportedCheckpoint"]:
215190
"""Get all the reported checkpoints so far.
216191
@@ -256,6 +231,7 @@ def train_func(config):
256231

257232

258233
@PublicAPI(stability="stable")
234+
@requires_train_worker()
259235
def get_dataset_shard(dataset_name: Optional[str] = None) -> Optional["DataIterator"]:
260236
"""Returns the :class:`ray.data.DataIterator` shard for this worker.
261237

python/ray/train/v2/torch/train_loop_utils.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
get_devices as get_devices_distributed,
2323
)
2424
from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
25+
from ray.train.v2._internal.util import requires_train_worker
2526
from ray.util.annotations import Deprecated, PublicAPI
2627

2728
logger = logging.getLogger(__name__)
@@ -38,11 +39,119 @@
3839
)
3940

4041

42+
@PublicAPI(stability="stable")
43+
@requires_train_worker()
4144
def get_device() -> torch.device:
45+
"""Gets the correct torch device configured for the current worker.
46+
47+
Returns the torch device for the current worker. If more than 1 GPU is
48+
requested per worker, returns the device with the lowest device index.
49+
50+
.. note::
51+
52+
If you requested multiple GPUs per worker, and want to get
53+
the full list of torch devices, please use
54+
:meth:`~ray.train.torch.get_devices`.
55+
56+
Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
57+
superset of the `ray.get_gpu_ids()`.
58+
59+
Returns:
60+
The torch device assigned to the current worker.
61+
62+
Examples:
63+
64+
Example: Launched 2 workers on the current node, each with 1 GPU
65+
66+
.. testcode::
67+
:skipif: True
68+
69+
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
70+
ray.get_gpu_ids() == [2]
71+
torch.cuda.is_available() == True
72+
get_device() == torch.device("cuda:0")
73+
74+
Example: Launched 4 workers on the current node, each with 1 GPU
75+
76+
.. testcode::
77+
:skipif: True
78+
79+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
80+
ray.get_gpu_ids() == [2]
81+
torch.cuda.is_available() == True
82+
get_device() == torch.device("cuda:2")
83+
84+
Example: Launched 2 workers on the current node, each with 2 GPUs
85+
86+
.. testcode::
87+
:skipif: True
88+
89+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
90+
ray.get_gpu_ids() == [2,3]
91+
torch.cuda.is_available() == True
92+
get_device() == torch.device("cuda:2")
93+
94+
95+
You can move a model to device by:
96+
97+
.. testcode::
98+
:skipif: True
99+
100+
model.to(ray.train.torch.get_device())
101+
102+
Instead of manually checking the device type:
103+
104+
.. testcode::
105+
:skipif: True
106+
107+
model.to("cuda" if torch.cuda.is_available() else "cpu")
108+
"""
42109
return get_devices()[0]
43110

44111

112+
@PublicAPI(stability="beta")
113+
@requires_train_worker()
45114
def get_devices() -> List[torch.device]:
115+
"""Gets the list of torch devices configured for the current worker.
116+
117+
Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
118+
superset of the `ray.get_gpu_ids()`.
119+
120+
Returns:
121+
The list of torch devices assigned to the current worker.
122+
123+
Examples:
124+
125+
Example: Launched 2 workers on the current node, each with 1 GPU
126+
127+
.. testcode::
128+
:skipif: True
129+
130+
os.environ["CUDA_VISIBLE_DEVICES"] == "2,3"
131+
ray.get_gpu_ids() == [2]
132+
torch.cuda.is_available() == True
133+
get_devices() == [torch.device("cuda:0")]
134+
135+
Example: Launched 4 workers on the current node, each with 1 GPU
136+
137+
.. testcode::
138+
:skipif: True
139+
140+
os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
141+
ray.get_gpu_ids() == [2]
142+
torch.cuda.is_available() == True
143+
get_devices() == [torch.device("cuda:2")]
144+
145+
Example: Launched 2 workers on the current node, each with 2 GPUs
146+
147+
.. testcode::
148+
:skipif: True
149+
150+
os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
151+
ray.get_gpu_ids() == [2,3]
152+
torch.cuda.is_available() == True
153+
get_devices() == [torch.device("cuda:2"), torch.device("cuda:3")]
154+
"""
46155
if get_train_fn_utils().is_distributed():
47156
return get_devices_distributed()
48157
else:

python/ray/tune/tests/test_api_migrations.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,19 @@ def test_trainable_fn_utils(tmp_path, monkeypatch, v2_enabled):
4040
)
4141

4242
def tune_fn(config):
43-
with asserting_context(match="ray.tune.get_checkpoint"):
43+
with asserting_context(match="get_checkpoint"):
4444
ray.train.get_checkpoint()
4545

4646
with warnings.catch_warnings():
4747
ray.tune.get_checkpoint()
4848

49-
with asserting_context(match="ray.tune.get_context"):
49+
with asserting_context(match="get_context"):
5050
ray.train.get_context()
5151

5252
with warnings.catch_warnings():
5353
ray.tune.get_context()
5454

55-
with asserting_context(match="ray.tune.report"):
55+
with asserting_context(match="report"):
5656
ray.train.report({"a": 1})
5757

5858
with warnings.catch_warnings():

0 commit comments

Comments (0)