Skip to content

Commit 14001ab

Browse files
justinvyu and Aydin-ab
authored and committed
[train][data] Fix iter_torch_batches usage of ray.train.torch.get_device when running outside Ray Train (ray-project#57816)
Train V2 doesn't allow running `ray.train.torch.get_device` outside of a Ray Train worker spawned by a trainer.fit() call. Previously, `get_device()` returns the 0th index GPU of the `ray.get_gpu_ids()` assigned to the current process, or "cpu" if the current process wasn't assigned GPUs via `ray.remote(num_gpus=x)`. This PR introduces a utility to detect whether we're running inside a Ray Train worker process or not (in v1 and v2) and updates Ray Data's iter_torch_batches to only call `get_device()` if in a Train worker process. This introduces a slight API breakage for users who spawned a custom GPU Ray task and used `iter_torch_batches` or `get_device()`. --------- Signed-off-by: Justin Yu <justinvyu@anyscale.com> Signed-off-by: Aydin Abiar <aydin@anyscale.com>
1 parent 9613d62 commit 14001ab

File tree

7 files changed

+60
-9
lines changed

7 files changed

+60
-9
lines changed

python/ray/data/iterator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ def iter_torch_batches(
408408
"""
409409

410410
from ray.train.torch import get_device
411+
from ray.train.utils import _in_ray_train_worker
411412

412413
if collate_fn is not None and (dtypes is not None or device != "auto"):
413414
raise ValueError(
@@ -424,7 +425,7 @@ def iter_torch_batches(
424425
if device == "auto":
425426
# Use the appropriate device for Ray Train, or falls back to CPU if
426427
# Ray Train is not being used.
427-
device = get_device()
428+
device = get_device() if _in_ray_train_worker() else "cpu"
428429

429430
from ray.air._internal.torch_utils import (
430431
move_tensors_to_device,

python/ray/train/_internal/session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,3 +1197,8 @@ def get_storage() -> StorageContext:
11971197
without notice between minor versions.
11981198
"""
11991199
return get_session().storage
1200+
1201+
1202+
def _in_ray_train_worker() -> bool:
    """Check if the current process is a Ray Train V1 worker.

    Returns:
        True only when a Train session exists for this process AND that
        session has a ``world_rank`` set. A driver/Tune-side session has no
        world rank, so this distinguishes actual training-worker processes
        from other processes that happen to hold a session.
    """
    # Fetch the session once instead of calling get_session() twice.
    session = get_session()
    return bool(session) and session.world_rank is not None

python/ray/train/tests/test_data_parallel_trainer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ray.train.backend import Backend, BackendConfig
1515
from ray.train.data_parallel_trainer import DataParallelTrainer
1616
from ray.train.tests.util import create_dict_checkpoint, load_dict_checkpoint
17+
from ray.train.utils import _in_ray_train_worker
1718
from ray.tune.callback import Callback
1819
from ray.tune.tune_config import TuneConfig
1920
from ray.tune.tuner import Tuner
@@ -378,6 +379,16 @@ def train_func():
378379
trainer.fit()
379380

380381

382+
def test_in_ray_train_worker(ray_start_4_cpus):
    """`_in_ray_train_worker` is False on the driver and True inside a worker."""
    # The test process itself is not a Train worker.
    assert not _in_ray_train_worker()

    def assert_inside_worker():
        # This runs inside a worker spawned by trainer.fit().
        assert _in_ray_train_worker()

    DataParallelTrainer(assert_inside_worker).fit()
390+
391+
381392
if __name__ == "__main__":
382393
import sys
383394

python/ray/train/tests/test_iter_torch_batches_gpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def test_custom_batch_collate_fn(
281281
# Set the device that's returned by device="auto" -> get_device()
282282
# This is used in `finalize_fn` to move the tensors to the correct device.
283283
device = torch.device(device)
284+
monkeypatch.setattr(ray.train.utils, "_in_ray_train_worker", lambda: True)
284285
monkeypatch.setattr(ray.train.torch, "get_device", lambda: device)
285286

286287
ds = ray.data.from_items(

python/ray/train/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,20 @@ def _log_deprecation_warning(message: str):
1717
RayDeprecationWarning,
1818
stacklevel=2,
1919
)
20+
21+
22+
def _in_ray_train_worker() -> bool:
    """Check if the current process is a Ray Train worker (V1 or V2).

    Dispatches to the version-specific detection helper based on whether
    Train V2 is enabled. Imports are deferred to avoid import cycles.
    """
    from ray.train.v2._internal.constants import is_v2_enabled

    if is_v2_enabled():
        from ray.train.v2._internal.util import (
            _in_ray_train_worker as _in_v2_worker,
        )

        return _in_v2_worker()

    from ray.train._internal.session import (
        _in_ray_train_worker as _in_v1_worker,
    )

    return _in_v1_worker()

python/ray/train/v2/_internal/util.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,14 @@ def construct_user_exception_with_traceback(
238238
)
239239
logger.error(f"Error in training function:\n{exc_traceback_str}")
240240
return UserExceptionWithTraceback(e, traceback_str=exc_traceback_str)
241+
242+
243+
def _in_ray_train_worker() -> bool:
    """Check if the current process is a Ray Train V2 worker.

    ``get_train_fn_utils()`` raises ``RuntimeError`` when called outside a
    Train worker, so a successful call is the detection signal.
    """
    from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils

    try:
        get_train_fn_utils()
    except RuntimeError:
        return False
    return True

python/ray/train/v2/tests/test_util.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
11
import pytest
22

33
import ray
4+
from ray.train.utils import _in_ray_train_worker
45
from ray.train.v2._internal.util import ray_get_safe
5-
6-
7-
@pytest.fixture(scope="module")
8-
def ray_start_4_cpus():
9-
ray.init(num_cpus=4)
10-
yield
11-
ray.shutdown()
6+
from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
127

138

149
@pytest.mark.parametrize("type", ["task", "actor_task"])
1510
@pytest.mark.parametrize("failing", [True, False])
1611
@pytest.mark.parametrize("task_list", [True, False])
17-
def test_ray_get_safe(type, failing, task_list):
12+
def test_ray_get_safe(ray_start_4_cpus, type, failing, task_list):
1813
num_tasks = 4
1914

2015
if type == "task":
@@ -56,6 +51,16 @@ def f(self):
5651
assert out == 1
5752

5853

54+
def test_in_ray_train_worker(ray_start_4_cpus):
    """`_in_ray_train_worker` is False on the driver and True inside a worker."""
    # The driver process is not a Train worker.
    assert not _in_ray_train_worker()

    def assert_inside_worker():
        # Executed inside a V2 worker spawned by trainer.fit().
        assert _in_ray_train_worker()

    DataParallelTrainer(assert_inside_worker).fit()
62+
63+
5964
if __name__ == "__main__":
6065
import sys
6166

0 commit comments

Comments (0)