[train][data] Fix iter_torch_batches usage of ray.train.torch.get_device when running outside Ray Train#57816

Merged
justinvyu merged 6 commits into ray-project:master from justinvyu:fix_data_get_device on Oct 17, 2025
Conversation

@justinvyu
Contributor

Description

Train V2 doesn't allow calling ray.train.torch.get_device outside of a Ray Train worker spawned by a trainer.fit() call. Previously, get_device() returned the GPU at index 0 of the ray.get_gpu_ids() assigned to the current process, or "cpu" if the current process wasn't assigned GPUs via ray.remote(num_gpus=x).

This PR introduces a utility to detect whether we're running inside a Ray Train worker process (in both V1 and V2) and updates Ray Data's iter_torch_batches to only call get_device() when running in a Train worker process.

This introduces a slight API breakage for users who spawned a custom GPU Ray task and used iter_torch_batches or get_device(): outside a Train worker, iter_torch_batches now falls back to CPU instead of the task's assigned GPU.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a utility _in_ray_train_worker to correctly detect whether code is running within a Ray Train worker, for both V1 and V2. This utility is then used in iter_torch_batches to conditionally call ray.train.torch.get_device, fixing an issue where get_device would fail when used outside a Ray Train context with Train V2. The changes are well-structured, with separate logic for V1 and V2, and are accompanied by good test coverage. My feedback includes a couple of minor suggestions to improve code style and maintainability.

@ray-gardener ray-gardener bot added train Ray Train Related Issue data Ray Data-related issues labels Oct 17, 2025
@justinvyu justinvyu added the go add ONLY when ready to merge, run all tests label Oct 17, 2025
# Use the appropriate device for Ray Train, or fall back to CPU if
# Ray Train is not being used.
-device = get_device()
+device = get_device() if _in_ray_train_worker() else "cpu"
Contributor

nit: Would it be possible to roll this condition into get_device?

Contributor Author

I considered this, but our current stance is that we don't want people calling get_device() (a Train worker utility) outside of Train workers (it will error in V2), so I fully gate the call to this method.
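The detection idea discussed here can be sketched as follows. This is illustrative only: the real _in_ray_train_worker checks Ray Train's internal session state, and the module-level flags and start_v2_worker helper below are made up for the example:

```python
# Illustrative only: stand-ins for Ray Train's internal per-worker state.
# The real utility inspects internal session objects; these flags mimic
# that for both API generations.

_v1_session = None  # populated by V1 workers (hypothetical)
_v2_context = None  # populated by V2 workers (hypothetical)


def in_ray_train_worker() -> bool:
    # A process counts as a Train worker if either the V1 session or the
    # V2 worker context was initialized by a trainer.fit() call.
    return _v1_session is not None or _v2_context is not None


def start_v2_worker() -> None:
    # Stand-in for what the Train runtime does when spawning a worker.
    global _v2_context
    _v2_context = object()
```

Gating on a single predicate like this lets callers such as iter_torch_batches stay agnostic to which Train API generation initialized the process.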

@justinvyu justinvyu enabled auto-merge (squash) October 17, 2025 20:11
@justinvyu justinvyu merged commit aff37bf into ray-project:master Oct 17, 2025
8 checks passed
@justinvyu justinvyu deleted the fix_data_get_device branch October 17, 2025 20:15
justinyeh1995 pushed a commit to justinyeh1995/ray that referenced this pull request Oct 20, 2025
xinyuangui2 pushed a commit to xinyuangui2/ray that referenced this pull request Oct 22, 2025
elliot-barn pushed a commit that referenced this pull request Oct 23, 2025
landscapepainter pushed a commit to landscapepainter/ray that referenced this pull request Nov 17, 2025
Aydin-ab pushed a commit to Aydin-ab/ray-aydin that referenced this pull request Nov 19, 2025
Future-Outlier pushed a commit to Future-Outlier/ray that referenced this pull request Dec 7, 2025

Labels

data (Ray Data-related issues)
go (add ONLY when ready to merge, run all tests)
train (Ray Train Related Issue)

4 participants