[train] Improve error message if users call training function utils outside of a Ray Train worker #57863
Changes from 2 commits
`@@ -22,6 +22,7 @@`

```python
    get_devices as get_devices_distributed,
)
from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
from ray.train.v2._internal.util import requires_train_worker
from ray.util.annotations import Deprecated, PublicAPI

logger = logging.getLogger(__name__)
```

`@@ -38,11 +39,119 @@`

```python
)


@PublicAPI(stability="stable")
@requires_train_worker()
def get_device() -> torch.device:
    """Gets the correct torch device configured for the current worker.

    Returns the torch device for the current worker. If more than 1 GPU is
    requested per worker, returns the device with the lowest device index.

    .. note::

        If you requested multiple GPUs per worker, and want to get
        the full list of torch devices, please use
        :meth:`~ray.train.torch.get_devices`.

    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
    superset of the `ray.get_gpu_ids()`.

    Returns:
        The torch device assigned to the current worker.

    Examples:

        Example: Launched 2 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:0")
```
**Contributor:** Is this supposed to be

**Author:** `cuda:0` is the "logical cuda device", which points to the 0th index in the `CUDA_VISIBLE_DEVICES` string: `CUDA_VISIBLE_DEVICES[0] == "2"`
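To make the logical-device re-indexing in the author's reply concrete, here is a minimal sketch. The helper name `logical_device_indices` is hypothetical and not part of the PR; it just reproduces the mapping CUDA applies: each physical GPU id becomes its position within the `CUDA_VISIBLE_DEVICES` string.

```python
# Sketch (hypothetical helper, not from the PR): map physical GPU ids, as
# returned by ray.get_gpu_ids(), to the logical torch device indices that
# CUDA exposes after CUDA_VISIBLE_DEVICES re-indexing.
from typing import List


def logical_device_indices(cuda_visible_devices: str, gpu_ids: List[int]) -> List[int]:
    """Map each physical id to its position in the CUDA_VISIBLE_DEVICES list."""
    visible = [int(x) for x in cuda_visible_devices.split(",")]
    return [visible.index(gpu_id) for gpu_id in gpu_ids]


# With CUDA_VISIBLE_DEVICES="2,3", physical GPU 2 sits at index 0, which is
# why get_device() returns torch.device("cuda:0") in the example above.
print(logical_device_indices("2,3", [2]))      # [0]
print(logical_device_indices("0,1,2,3", [2]))  # [2]
```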
```python
        Example: Launched 4 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")

        Example: Launched 2 workers on the current node, each with 2 GPUs

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2,3]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")

    You can move a model to device by:

        .. testcode::
            :skipif: True

            model.to(ray.train.torch.get_device())

    Instead of manually checking the device type:

        .. testcode::
            :skipif: True

            model.to("cuda" if torch.cuda.is_available() else "cpu")
    """
```
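The `@requires_train_worker()` decorator this PR applies is imported from `ray.train.v2._internal.util`; its implementation is not shown in this diff. As a rough sketch of what such a guard might look like (all names and the error wording here are assumptions for illustration, not the PR's actual code):

```python
# Minimal sketch of a "must run inside a Ray Train worker" guard decorator.
# The real requires_train_worker lives in ray.train.v2._internal.util and
# may differ; _IN_TRAIN_WORKER is a stand-in for the real context check.
import functools

_IN_TRAIN_WORKER = False  # would be derived from worker context in practice


def requires_train_worker_sketch(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not _IN_TRAIN_WORKER:
            # The point of the PR: fail with a descriptive error instead of
            # an obscure one when the util is called outside a worker.
            raise RuntimeError(
                f"`{fn.__name__}` can only be called from inside a function "
                "that is launched by a Ray Train worker."
            )
        return fn(*args, **kwargs)

    return wrapper


@requires_train_worker_sketch
def get_device_sketch():
    return "cuda:0"
```

Calling `get_device_sketch()` outside a worker context then raises the descriptive `RuntimeError` instead of failing deeper inside the device-lookup machinery.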
```python
    return get_devices()[0]


@PublicAPI(stability="beta")
@requires_train_worker()
def get_devices() -> List[torch.device]:
    """Gets the correct torch device list configured for the current worker.

    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
    superset of the `ray.get_gpu_ids()`.

    Returns:
        The list of torch devices assigned to the current worker.

    Examples:

        Example: Launched 2 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_devices() == [torch.device("cuda:0")]

        Example: Launched 4 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_devices() == [torch.device("cuda:2")]

        Example: Launched 2 workers on the current node, each with 2 GPUs

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2,3]
            torch.cuda.is_available() == True
            get_devices() == [torch.device("cuda:2"), torch.device("cuda:3")]
    """
    if get_train_fn_utils().is_distributed():
        return get_devices_distributed()
    else:
```
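The diff cuts off inside the `else:` branch above, so its fallback behavior is not visible here. As a hedged sketch of the dispatch pattern only (the helper and the non-distributed fallback below are hypothetical stand-ins, not Ray's actual code):

```python
# Sketch of the distributed-vs-local dispatch shown in get_devices(). The
# is_distributed flag and device lists are plain stand-ins; the real code
# defers to get_devices_distributed() and to a branch truncated in the diff.
from typing import List


def get_devices_sketch(is_distributed: bool, assigned: List[str]) -> List[str]:
    if is_distributed:
        # Distributed run: return the devices assigned to this worker.
        return assigned
    # Non-distributed fallback (truncated in the diff); a single default
    # device is one plausible behavior.
    return ["cpu"]


print(get_devices_sketch(True, ["cuda:2", "cuda:3"]))  # ['cuda:2', 'cuda:3']
print(get_devices_sketch(False, []))                   # ['cpu']
```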