
Commit 3d8a611

[Feature] Add support for the musa device (#1453)
1 parent 109cd44 commit 3d8a611

22 files changed: 253 additions and 43 deletions


mmengine/device/__init__.py

Lines changed: 6 additions & 4 deletions
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_dipu_available, is_mlu_available, is_mps_available,
-                    is_npu_available, is_npu_support_full_precision)
+from .utils import (get_device, get_max_cuda_memory, get_max_musa_memory,
+                    is_cuda_available, is_dipu_available, is_mlu_available,
+                    is_mps_available, is_musa_available, is_npu_available,
+                    is_npu_support_full_precision)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
     'is_mlu_available', 'is_mps_available', 'is_npu_available',
-    'is_dipu_available', 'is_npu_support_full_precision'
+    'is_dipu_available', 'get_max_musa_memory', 'is_musa_available',
+    'is_npu_support_full_precision'
 ]

mmengine/device/utils.py

Lines changed: 37 additions & 1 deletion
@@ -22,6 +22,12 @@
 except Exception:
     IS_DIPU_AVAILABLE = False
 
+try:
+    import torch_musa  # noqa: F401
+    IS_MUSA_AVAILABLE = True
+except Exception:
+    IS_MUSA_AVAILABLE = False
+
 
 def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
     """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@@ -73,6 +79,34 @@ def is_dipu_available() -> bool:
     return IS_DIPU_AVAILABLE
 
 
+def get_max_musa_memory(device: Optional[torch.device] = None) -> int:
+    """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
+    a given device. By default, this returns the peak allocated memory since
+    the beginning of this program.
+
+    Args:
+        device (torch.device, optional): selected device. Returns
+            statistic for the current device, given by
+            :func:`~torch.musa.current_device`, if ``device`` is None.
+            Defaults to None.
+
+    Returns:
+        int: The maximum GPU memory occupied by tensors in megabytes
+        for a given device.
+    """
+    mem = torch.musa.max_memory_allocated(device=device)
+    mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
+                          dtype=torch.int,
+                          device=device)
+    # TODO:[email protected]: This function is not supported by musa yet.
+    # torch.musa.reset_peak_memory_stats()
+    return int(mem_mb.item())
+
+
+def is_musa_available() -> bool:
+    return IS_MUSA_AVAILABLE
+
+
 def is_npu_support_full_precision() -> bool:
     """Returns True if npu devices support full precision training."""
     version_of_support_full_precision = 220
@@ -91,12 +125,14 @@ def is_npu_support_full_precision() -> bool:
     DEVICE = 'mps'
 elif is_dipu_available():
     DEVICE = 'dipu'
+elif is_musa_available():
+    DEVICE = 'musa'
 
 
 def get_device() -> str:
     """Returns the currently existing device type.
 
     Returns:
-        str: cuda | npu | mlu | mps | cpu.
+        str: cuda | npu | mlu | mps | musa | cpu.
     """
     return DEVICE
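
Taken together, the utils.py changes let user code query MUSA the same way it queries CUDA. A minimal usage sketch, assuming torch_musa is installed and a MUSA device is visible; only get_device, is_musa_available and get_max_musa_memory come from the diff above, the tensor work is illustrative:

import torch

from mmengine.device import (get_device, get_max_musa_memory,
                             is_musa_available)

if is_musa_available():
    # get_device() now reports 'musa' once torch_musa imports successfully.
    assert get_device() == 'musa'
    x = torch.randn(1024, 1024, device='musa')
    y = x @ x
    # Peak allocated memory in MB. The diff keeps
    # torch.musa.reset_peak_memory_stats() commented out because torch_musa
    # does not support it yet, so the peak is never reset.
    print(f'peak MUSA memory: {get_max_musa_memory()} MB')
else:
    print('MUSA not available, current device:', get_device())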

mmengine/dist/dist.py

Lines changed: 17 additions & 0 deletions
@@ -415,12 +415,16 @@ def _broadcast_object_list(object_list: List[Any],
     current_device = torch.device('cpu')
     is_hccl_backend = group_backend == 'hccl'
     is_cncl_backend = group_backend == 'cncl'
+    is_mccl_backend = group_backend == 'mccl'
     if is_hccl_backend:
         current_device = torch.device('npu', torch.npu.current_device())
         object_sizes_tensor = object_sizes_tensor.to(current_device)
     elif is_cncl_backend:
         current_device = torch.device('mlu', torch.mlu.current_device())
         object_sizes_tensor = object_sizes_tensor.to(current_device)
+    elif is_mccl_backend:
+        current_device = torch.device('musa', torch.musa.current_device())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
     elif is_nccl_backend:
         # See note about using torch.cuda.current_device() here in
         # docstring. We cannot simply use my_rank since rank == device is
@@ -624,13 +628,21 @@ def _all_gather_object(object_list: List[Any],
     group_backend = get_backend(group)
     current_device = torch.device('cpu')
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
+    is_mccl_backend = group_backend == 'mccl'
     if is_nccl_backend:
         # See note about using torch.cuda.current_device() here in docstring.
         # We cannot simply use my_rank since rank == device is not necessarily
         # true.
         current_device = torch.device('cuda', torch.cuda.current_device())
         input_tensor = input_tensor.to(current_device)
         local_size = local_size.to(current_device)
+    elif is_mccl_backend:
+        # See note about using torch.musa.current_device() here in docstring.
+        # We cannot simply use my_rank since rank == device is not necessarily
+        # true.
+        current_device = torch.device('musa', torch.musa.current_device())
+        input_tensor = input_tensor.to(current_device)
+        local_size = local_size.to(current_device)
     # Gather all local sizes. This is so that we can find the max size, and
     # index until the correct size when deserializing the tensors.
     group_size = get_world_size(group=group)
@@ -776,10 +788,15 @@ def _gather_object(obj: Any,
     group_backend = get_backend(group)
     current_device = torch.device('cpu')
     is_nccl_backend = group_backend == torch_dist.Backend.NCCL
+    is_mccl_backend = group_backend == 'mccl'
     if is_nccl_backend:
         current_device = torch.device('cuda', torch.cuda.current_device())
         input_tensor = input_tensor.to(current_device)
         local_size = local_size.to(current_device)
+    elif is_mccl_backend:
+        current_device = torch.device('musa', torch.musa.current_device())
+        input_tensor = input_tensor.to(current_device)
+        local_size = local_size.to(current_device)
     # Gather all local sizes. This is so that we can find the max size, and
     # index until the correct size when deserializing the tensors.
     group_size = get_world_size(group=group)
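
The three helpers patched above back the public object collectives, so on an 'mccl' process group their temporary tensors now land on the 'musa' device automatically. A hedged sketch from the caller's side, assuming the process group was already initialized with the mccl backend (see dist/utils.py below); the payload is illustrative:

from mmengine import dist

payload = {'rank': dist.get_rank()}
# all_gather_object and broadcast_object_list reuse the patched private
# helpers, so no musa-specific code is needed on the caller's side.
gathered = dist.all_gather_object(payload)
dist.broadcast_object_list([payload], src=0)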

mmengine/dist/utils.py

Lines changed: 13 additions & 1 deletion
@@ -11,7 +11,8 @@
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
-from mmengine.device import is_mlu_available, is_npu_available
+from mmengine.device import (is_mlu_available, is_npu_available,
+                             is_musa_available)
 
 from collections.abc import Iterable, Mapping
 
@@ -117,6 +118,14 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
             rank=rank,
             world_size=int(os.environ['WORLD_SIZE']),
             **kwargs)
+    elif is_musa_available():
+        import torch_musa  # noqa: F401
+        torch.musa.set_device(rank)
+        torch_dist.init_process_group(
+            backend='mccl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
     else:
         torch.cuda.set_device(local_rank)
 
@@ -527,6 +536,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
         return torch.device('mlu', torch.mlu.current_device())
     elif backend == 'smddp':
         return torch.device('cuda', torch.cuda.current_device())
+    elif backend == 'mccl':
+        import torch_musa
+        return torch.device('musa', torch_musa.current_device())
     else:
         # GLOO and MPI backends use cpu device by default
         return torch.device('cpu')
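
With the branch above, a torchrun launch on a MUSA host initializes the 'mccl' backend without user-side changes. A hedged sketch of what that looks like, assuming torchrun has set RANK, LOCAL_RANK and WORLD_SIZE and torch_musa is installed (script name and flags are illustrative):

# Launched e.g. with: torchrun --nproc_per_node=8 train.py
from mmengine.dist import get_comm_device, init_dist

# _init_dist_pytorch() checks is_musa_available() before the CUDA fallback,
# so the default backend argument is effectively replaced by 'mccl'.
init_dist(launcher='pytorch')
print(get_comm_device())  # expected: device(type='musa', index=...)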

mmengine/hooks/empty_cache_hook.py

Lines changed: 13 additions & 3 deletions
@@ -4,6 +4,7 @@
 import torch
 
 from mmengine.registry import HOOKS
+from ..device import is_cuda_available, is_musa_available
 from .hook import Hook
 
 DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -49,7 +50,10 @@ def _after_iter(self,
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_iter:
-            torch.cuda.empty_cache()
+            if is_cuda_available():
+                torch.cuda.empty_cache()
+            elif is_musa_available():
+                torch.musa.empty_cache()
 
     def _before_epoch(self, runner, mode: str = 'train') -> None:
         """Empty cache before an epoch.
@@ -59,7 +63,10 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_before_epoch:
-            torch.cuda.empty_cache()
+            if is_cuda_available():
+                torch.cuda.empty_cache()
+            elif is_musa_available():
+                torch.musa.empty_cache()
 
     def _after_epoch(self, runner, mode: str = 'train') -> None:
         """Empty cache after an epoch.
@@ -69,4 +76,7 @@ def _after_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_epoch:
-            torch.cuda.empty_cache()
+            if is_cuda_available():
+                torch.cuda.empty_cache()
+            elif is_musa_available():
+                torch.musa.empty_cache()
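
Because the hook now dispatches on the available device, existing configs keep working unchanged on MUSA. A hedged config fragment for completeness (the three flags are the hook's existing options; the surrounding config is assumed):

# EmptyCacheHook now calls torch.musa.empty_cache() instead of
# torch.cuda.empty_cache() when only a MUSA device is present.
custom_hooks = [
    dict(type='EmptyCacheHook',
         before_epoch=False,
         after_epoch=True,
         after_iter=False),
]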

mmengine/logging/logger.py

Lines changed: 31 additions & 15 deletions
@@ -398,22 +398,38 @@ def _get_device_id():
     except ImportError:
         return 0
     else:
-        local_rank = int(os.getenv('LOCAL_RANK', '0'))
-        # TODO: return device id of npu and mlu.
-        if not torch.cuda.is_available():
-            return local_rank
-        cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
-        if cuda_visible_devices is None:
-            num_device = torch.cuda.device_count()
-            cuda_visible_devices = list(range(num_device))
-        else:
-            cuda_visible_devices = cuda_visible_devices.split(',')
+        MUSA_AVAILABLE = False
         try:
-            return int(cuda_visible_devices[local_rank])
-        except ValueError:
-            # handle case for Multi-Instance GPUs
-            # see #1148 for details
-            return cuda_visible_devices[local_rank]
+            import torch_musa
+            MUSA_AVAILABLE = True
+        except ImportError:
+            pass
+        if MUSA_AVAILABLE:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
+            musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
+            if musa_visible_devices is None:
+                num_device = torch_musa.device_count()
+                musa_visible_devices = list(range(num_device))
+            else:
+                musa_visible_devices = musa_visible_devices.split(',')
+            return int(musa_visible_devices[local_rank])
+        else:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
+            # TODO: return device id of npu and mlu.
+            if not torch.cuda.is_available():
+                return local_rank
+            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+            if cuda_visible_devices is None:
+                num_device = torch.cuda.device_count()
+                cuda_visible_devices = list(range(num_device))
+            else:
+                cuda_visible_devices = cuda_visible_devices.split(',')
+            try:
+                return int(cuda_visible_devices[local_rank])
+            except ValueError:
+                # handle case for Multi-Instance GPUs
+                # see #1148 for details
+                return cuda_visible_devices[local_rank]
 
 
 def _get_host_info() -> str:

mmengine/model/base_model/base_model.py

Lines changed: 15 additions & 0 deletions
@@ -222,6 +222,21 @@ def cuda(
         self._set_device(torch.device(device))
         return super().cuda(device)
 
+    def musa(
+        self,
+        device: Optional[Union[int, str, torch.device]] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.musa`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        if device is None or isinstance(device, int):
+            device = torch.device('musa', index=device)
+        self._set_device(torch.device(device))
+        return super().musa(device)
+
     def mlu(
         self,
         device: Union[int, str, torch.device, None] = None,

mmengine/model/base_model/data_preprocessor.py

Lines changed: 9 additions & 0 deletions
@@ -113,6 +113,15 @@ def cuda(self, *args, **kwargs) -> nn.Module:
         self._device = torch.device(torch.cuda.current_device())
         return super().cuda()
 
+    def musa(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self._device = torch.device(torch.musa.current_device())
+        return super().musa()
+
     def npu(self, *args, **kwargs) -> nn.Module:
         """Overrides this method to set the :attr:`device`
 
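
The two musa() overrides above mirror the existing cuda()/npu()/mlu() methods, so moving a model and its data preprocessor to a MUSA device looks the same as moving it to CUDA. A hedged sketch, with a hypothetical ToyModel standing in for any BaseModel subclass:

import torch
import torch.nn as nn

from mmengine.model import BaseModel


class ToyModel(BaseModel):
    """Hypothetical model used only to illustrate .musa()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        return self.linear(inputs)


# Per the new docstring, BaseModel.musa() also takes care of the attached
# BaseDataPreprocessor (see the data_preprocessor.py change above).
model = ToyModel().musa()
out = model(torch.randn(3, 4, device='musa'))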

mmengine/optim/optimizer/amp_optimizer_wrapper.py

Lines changed: 4 additions & 3 deletions
@@ -6,7 +6,7 @@
 import torch.nn as nn
 
 from mmengine.device import (is_cuda_available, is_mlu_available,
-                             is_npu_available)
+                             is_musa_available, is_npu_available)
 from mmengine.registry import OPTIM_WRAPPERS
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -74,8 +74,9 @@ def __init__(self,
         assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), (
             '`torch.cuda.amp` is only available when pytorch version >= 1.6')
         assert is_cuda_available() or is_npu_available() or is_mlu_available(
-        ), ('``AmpOptimizerWrapper`` is only available training '
-            'on gpu, npu or mlu')
+        ) or is_musa_available(), (
+            '``AmpOptimizerWrapper`` is only available training '
+            'on gpu, npu, mlu or musa')
         super().__init__(**kwargs)
         self._scale_update_param = None
 
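
With the relaxed assertion, mixed-precision training can be requested on a MUSA-only machine the same way as on CUDA. A hedged config sketch (the optimizer settings are placeholders):

# Assumed config fragment; the AMP wrapper registered in OPTIM_WRAPPERS by
# this module no longer raises when MUSA is the only accelerator available.
optim_wrapper = dict(
    type='AmpOptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9))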

mmengine/runner/amp.py

Lines changed: 7 additions & 1 deletion
@@ -135,7 +135,13 @@ def autocast(device_type: Optional[str] = None,
 
     elif device_type == 'npu':
         pass
-
+    elif device_type == 'musa':
+        if dtype is None:
+            dtype = torch.get_autocast_gpu_dtype()
+        with torch.musa.amp.autocast(
+                enabled=enabled, dtype=dtype, cache_enabled=cache_enabled):
+            yield
+        return
     else:
         # Device like MPS does not support fp16 training or testing.
         # If an inappropriate device is set and fp16 is enabled, an error
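
From the caller's perspective the new branch keeps the existing autocast interface; it simply assumes torch_musa provides torch.musa.amp.autocast. A hedged usage sketch (tensor shapes are illustrative):

import torch

from mmengine.runner import autocast

# When no dtype is passed, the diff falls back to
# torch.get_autocast_gpu_dtype(), i.e. float16 by default.
with autocast(device_type='musa'):
    a = torch.randn(8, 8, device='musa')
    b = a @ a
print(b.dtype)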
