|
22 | 22 | get_devices as get_devices_distributed, |
23 | 23 | ) |
24 | 24 | from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils |
| 25 | +from ray.train.v2._internal.util import requires_train_worker |
25 | 26 | from ray.util.annotations import Deprecated, PublicAPI |
26 | 27 |
|
27 | 28 | logger = logging.getLogger(__name__) |
|
38 | 39 | ) |
39 | 40 |
|
40 | 41 |
|
@PublicAPI(stability="stable")
@requires_train_worker()
def get_device() -> torch.device:
    """Return the torch device configured for the current train worker.

    If more than one GPU was requested per worker, the device with the
    lowest device index is returned.

    .. note::

        If you requested multiple GPUs per worker, and want to get
        the full list of torch devices, please use
        :meth:`~ray.train.torch.get_devices`.

    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
    superset of the `ray.get_gpu_ids()`.

    Returns:
        The torch device assigned to the current worker.

    Examples:

        Example: Launched 2 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:0")

        Example: Launched 4 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")

        Example: Launched 2 workers on the current node, each with 2 GPUs

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2,3]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")


        You can move a model to device by:

        .. testcode::
            :skipif: True

            model.to(ray.train.torch.get_device())

        Instead of manually checking the device type:

        .. testcode::
            :skipif: True

            model.to("cuda" if torch.cuda.is_available() else "cpu")
    """
    # Delegate to get_devices() and take the first (lowest-index) device.
    assigned_devices = get_devices()
    return assigned_devices[0]
43 | 110 |
|
44 | 111 |
|
| 112 | +@PublicAPI(stability="beta") |
| 113 | +@requires_train_worker() |
45 | 114 | def get_devices() -> List[torch.device]: |
| 115 | + """Gets the list of torch devices configured for the current worker. |
| 116 | +
|
| 117 | + Assumes that `CUDA_VISIBLE_DEVICES` is set and is a |
| 118 | + superset of the `ray.get_gpu_ids()`. |
| 119 | +
|
| 120 | + Returns: |
| 121 | + The list of torch devices assigned to the current worker. |
| 122 | +
|
| 123 | + Examples: |
| 124 | +
|
| 125 | + Example: Launched 2 workers on the current node, each with 1 GPU |
| 126 | +
|
| 127 | + .. testcode:: |
| 128 | + :skipif: True |
| 129 | +
|
| 130 | + os.environ["CUDA_VISIBLE_DEVICES"] == "2,3" |
| 131 | + ray.get_gpu_ids() == [2] |
| 132 | + torch.cuda.is_available() == True |
| 133 | + get_devices() == [torch.device("cuda:0")] |
| 134 | +
|
| 135 | + Example: Launched 4 workers on the current node, each with 1 GPU |
| 136 | +
|
| 137 | + .. testcode:: |
| 138 | + :skipif: True |
| 139 | +
|
| 140 | + os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3" |
| 141 | + ray.get_gpu_ids() == [2] |
| 142 | + torch.cuda.is_available() == True |
| 143 | + get_devices() == [torch.device("cuda:2")] |
| 144 | +
|
| 145 | + Example: Launched 2 workers on the current node, each with 2 GPUs |
| 146 | +
|
| 147 | + .. testcode:: |
| 148 | + :skipif: True |
| 149 | +
|
| 150 | + os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3" |
| 151 | + ray.get_gpu_ids() == [2,3] |
| 152 | + torch.cuda.is_available() == True |
| 153 | + get_devices() == [torch.device("cuda:2"), torch.device("cuda:3")] |
| 154 | + """ |
46 | 155 | if get_train_fn_utils().is_distributed(): |
47 | 156 | return get_devices_distributed() |
48 | 157 | else: |
|
0 commit comments