[train][data] Fix iter_torch_batches usage of ray.train.torch.get_device when running outside Ray Train (#57816)
Conversation
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Code Review
This pull request introduces a utility _in_ray_train_worker to correctly detect whether code is running within a Ray Train worker, for both V1 and V2. This utility is then used in iter_torch_batches to conditionally call ray.train.torch.get_device, fixing an issue where get_device would fail when used outside a Ray Train context with Train V2. The changes are well-structured, with separate logic for V1 and V2, and are accompanied by good test coverage. My feedback includes a couple of minor suggestions to improve code style and maintainability.
  # Use the appropriate device for Ray Train, or fall back to CPU if
  # Ray Train is not being used.
- device = get_device()
+ device = get_device() if _in_ray_train_worker() else "cpu"
nit: Would it be possible to roll this condition into `get_device`?
I considered this, but our current stance is that we don't want people calling `get_device()` (a Train worker utility) outside of Train workers, since it will error in V2, so I fully gate the call to this method.
…vice` when running outside Ray Train (ray-project#57816)

Train V2 doesn't allow running `ray.train.torch.get_device` outside of a Ray Train worker spawned by a `trainer.fit()` call. Previously, `get_device()` returned the 0th-index GPU of the `ray.get_gpu_ids()` assigned to the current process, or "cpu" if the current process wasn't assigned GPUs via `ray.remote(num_gpus=x)`.

This PR introduces a utility to detect whether we're running inside a Ray Train worker process (in V1 and V2) and updates Ray Data's `iter_torch_batches` to only call `get_device()` inside a Train worker process.

This introduces a slight API breakage for users who spawned a custom GPU Ray task and used `iter_torch_batches` or `get_device()`.

---------

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Description
Train V2 doesn't allow running `ray.train.torch.get_device` outside of a Ray Train worker spawned by a `trainer.fit()` call. Previously, `get_device()` returned the 0th-index GPU of the `ray.get_gpu_ids()` assigned to the current process, or "cpu" if the current process wasn't assigned GPUs via `ray.remote(num_gpus=x)`.

This PR introduces a utility to detect whether we're running inside a Ray Train worker process (in V1 and V2) and updates Ray Data's `iter_torch_batches` to only call `get_device()` inside a Train worker process.

This introduces a slight API breakage for users who spawned a custom GPU Ray task and used `iter_torch_batches` or `get_device()`.
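The previous fallback behavior described above can be sketched as pure Python. This is a simplified stand-in, not Ray's actual code: `legacy_get_device` is a hypothetical name, its `gpu_ids` argument plays the role of `ray.get_gpu_ids()` for the current process, and the `"cuda:N"` string simplifies the `torch.device` object the real API returns.

```python
from typing import List


def legacy_get_device(gpu_ids: List[int]) -> str:
    """Old behavior: the 0th GPU id assigned to this process, or "cpu"
    when the process was not assigned any GPUs."""
    # gpu_ids stands in for ray.get_gpu_ids() inside the process.
    return f"cuda:{gpu_ids[0]}" if gpu_ids else "cpu"
```

Under the new gating, a custom GPU Ray task that is not a Train worker no longer reaches this logic at all: `iter_torch_batches` places tensors on "cpu", and the caller must move them to the GPU explicitly.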