diff --git a/mmengine/device/__init__.py b/mmengine/device/__init__.py
index bfd82a858a..88937d5592 100644
--- a/mmengine/device/__init__.py
+++ b/mmengine/device/__init__.py
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .utils import (get_device, get_max_cuda_memory, is_cuda_available,
-                    is_dipu_available, is_mlu_available, is_mps_available,
-                    is_npu_available, is_npu_support_full_precision)
+from .utils import (get_device, get_max_cuda_memory, get_max_musa_memory,
+                    is_cuda_available, is_dipu_available, is_mlu_available,
+                    is_mps_available, is_musa_available, is_npu_available,
+                    is_npu_support_full_precision)
 
 __all__ = [
     'get_max_cuda_memory', 'get_device', 'is_cuda_available',
     'is_mlu_available', 'is_mps_available', 'is_npu_available',
-    'is_dipu_available', 'is_npu_support_full_precision'
+    'is_dipu_available', 'get_max_musa_memory', 'is_musa_available',
+    'is_npu_support_full_precision'
 ]
diff --git a/mmengine/device/utils.py b/mmengine/device/utils.py
index 0bb69d2ea9..2fd56d80ed 100644
--- a/mmengine/device/utils.py
+++ b/mmengine/device/utils.py
@@ -22,6 +22,12 @@
 except Exception:
     IS_DIPU_AVAILABLE = False
 
+try:
+    import torch_musa  # noqa: F401
+    IS_MUSA_AVAILABLE = True
+except Exception:
+    IS_MUSA_AVAILABLE = False
+
 
 def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
     """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
@@ -73,6 +79,34 @@ def is_dipu_available() -> bool:
     return IS_DIPU_AVAILABLE
 
 
+def get_max_musa_memory(device: Optional[torch.device] = None) -> int:
+    """Returns the maximum GPU memory occupied by tensors in megabytes (MB) for
+    a given device. By default, this returns the peak allocated memory since
+    the beginning of this program.
+
+    Args:
+        device (torch.device, optional): selected device. Returns
+            statistic for the current device, given by
+            :func:`~torch.musa.current_device`, if ``device`` is None.
+            Defaults to None.
+
+    Returns:
+        int: The maximum GPU memory occupied by tensors in megabytes
+        for a given device.
+    """
+    mem = torch.musa.max_memory_allocated(device=device)
+    mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
+                          dtype=torch.int,
+                          device=device)
+    # TODO:haowen.han@mthreads.com: This function is not supported by musa yet.
+    # torch.musa.reset_peak_memory_stats()
+    return int(mem_mb.item())
+
+
+def is_musa_available() -> bool:
+    return IS_MUSA_AVAILABLE
+
+
 def is_npu_support_full_precision() -> bool:
     """Returns True if npu devices support full precision training."""
     version_of_support_full_precision = 220
@@ -91,12 +125,14 @@ def is_npu_support_full_precision() -> bool:
     DEVICE = 'mps'
 elif is_dipu_available():
     DEVICE = 'dipu'
+elif is_musa_available():
+    DEVICE = 'musa'
 
 
 def get_device() -> str:
     """Returns the currently existing device type.
 
     Returns:
-        str: cuda | npu | mlu | mps | cpu.
+        str: cuda | npu | mlu | mps | musa | cpu.
""" return DEVICE diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py index b6dd769f90..829e4b9250 100644 --- a/mmengine/dist/dist.py +++ b/mmengine/dist/dist.py @@ -415,12 +415,16 @@ def _broadcast_object_list(object_list: List[Any], current_device = torch.device('cpu') is_hccl_backend = group_backend == 'hccl' is_cncl_backend = group_backend == 'cncl' + is_mccl_backend = group_backend == 'mccl' if is_hccl_backend: current_device = torch.device('npu', torch.npu.current_device()) object_sizes_tensor = object_sizes_tensor.to(current_device) elif is_cncl_backend: current_device = torch.device('mlu', torch.mlu.current_device()) object_sizes_tensor = object_sizes_tensor.to(current_device) + elif is_mccl_backend: + current_device = torch.device('musa', torch.musa.current_device()) + object_sizes_tensor = object_sizes_tensor.to(current_device) elif is_nccl_backend: # See note about using torch.cuda.current_device() here in # docstring. We cannot simply use my_rank since rank == device is @@ -624,6 +628,7 @@ def _all_gather_object(object_list: List[Any], group_backend = get_backend(group) current_device = torch.device('cpu') is_nccl_backend = group_backend == torch_dist.Backend.NCCL + is_mccl_backend = group_backend == 'mccl' if is_nccl_backend: # See note about using torch.cuda.current_device() here in docstring. # We cannot simply use my_rank since rank == device is not necessarily @@ -631,6 +636,13 @@ def _all_gather_object(object_list: List[Any], current_device = torch.device('cuda', torch.cuda.current_device()) input_tensor = input_tensor.to(current_device) local_size = local_size.to(current_device) + elif is_mccl_backend: + # See note about using torch.musa.current_device() here in docstring. + # We cannot simply use my_rank since rank == device is not necessarily + # true. + current_device = torch.device('musa', torch.musa.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) # Gather all local sizes. This is so that we can find the max size, and # index until the correct size when deserializing the tensors. group_size = get_world_size(group=group) @@ -776,10 +788,15 @@ def _gather_object(obj: Any, group_backend = get_backend(group) current_device = torch.device('cpu') is_nccl_backend = group_backend == torch_dist.Backend.NCCL + is_mccl_backend = group_backend == 'mccl' if is_nccl_backend: current_device = torch.device('cuda', torch.cuda.current_device()) input_tensor = input_tensor.to(current_device) local_size = local_size.to(current_device) + elif is_mccl_backend: + current_device = torch.device('musa', torch.musa.current_device()) + input_tensor = input_tensor.to(current_device) + local_size = local_size.to(current_device) # Gather all local sizes. This is so that we can find the max size, and # index until the correct size when deserializing the tensors. 
     group_size = get_world_size(group=group)
diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py
index d1d19d8f68..d99e837ae4 100644
--- a/mmengine/dist/utils.py
+++ b/mmengine/dist/utils.py
@@ -11,7 +11,8 @@
 from torch import Tensor
 from torch import distributed as torch_dist
 from torch.distributed import ProcessGroup
 
-from mmengine.device import is_mlu_available, is_npu_available
+from mmengine.device import (is_mlu_available, is_npu_available,
+                             is_musa_available)
 
 from collections.abc import Iterable, Mapping
@@ -116,6 +117,14 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
             rank=rank,
             world_size=int(os.environ['WORLD_SIZE']),
             **kwargs)
+    elif is_musa_available():
+        import torch_musa  # noqa: F401
+        torch.musa.set_device(rank)
+        torch_dist.init_process_group(
+            backend='mccl',
+            rank=rank,
+            world_size=int(os.environ['WORLD_SIZE']),
+            **kwargs)
     else:
         # LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
         local_rank = int(os.environ['LOCAL_RANK'])
@@ -528,6 +537,9 @@ def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
         return torch.device('mlu', torch.mlu.current_device())
     elif backend == 'smddp':
         return torch.device('cuda', torch.cuda.current_device())
+    elif backend == 'mccl':
+        import torch_musa
+        return torch.device('musa', torch_musa.current_device())
     else:
         # GLOO and MPI backends use cpu device by default
         return torch.device('cpu')
diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py
index b9b5eba0ed..9a92cdebfe 100644
--- a/mmengine/hooks/empty_cache_hook.py
+++ b/mmengine/hooks/empty_cache_hook.py
@@ -4,6 +4,7 @@
 import torch
 
 from mmengine.registry import HOOKS
+from ..device import is_cuda_available, is_musa_available
 from .hook import Hook
 
 DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -49,7 +50,10 @@ def _after_iter(self,
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_iter:
-            torch.cuda.empty_cache()
+            if is_cuda_available():
+                torch.cuda.empty_cache()
+            elif is_musa_available():
+                torch.musa.empty_cache()
 
     def _before_epoch(self, runner, mode: str = 'train') -> None:
         """Empty cache before an epoch.
@@ -59,7 +63,10 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_before_epoch:
-            torch.cuda.empty_cache()
+            if is_cuda_available():
+                torch.cuda.empty_cache()
+            elif is_musa_available():
+                torch.musa.empty_cache()
 
     def _after_epoch(self, runner, mode: str = 'train') -> None:
         """Empty cache after an epoch.
@@ -69,4 +76,7 @@ def _after_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_epoch:
-            torch.cuda.empty_cache()
+            if is_cuda_available():
+                torch.cuda.empty_cache()
+            elif is_musa_available():
+                torch.musa.empty_cache()
diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py
index 9b2cb9da66..839a08cdda 100644
--- a/mmengine/logging/logger.py
+++ b/mmengine/logging/logger.py
@@ -398,22 +398,38 @@ def _get_device_id():
     except ImportError:
         return 0
     else:
-        local_rank = int(os.getenv('LOCAL_RANK', '0'))
-        # TODO: return device id of npu and mlu.
-        if not torch.cuda.is_available():
-            return local_rank
-        cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
-        if cuda_visible_devices is None:
-            num_device = torch.cuda.device_count()
-            cuda_visible_devices = list(range(num_device))
-        else:
-            cuda_visible_devices = cuda_visible_devices.split(',')
+        MUSA_AVAILABLE = False
         try:
-            return int(cuda_visible_devices[local_rank])
-        except ValueError:
-            # handle case for Multi-Instance GPUs
-            # see #1148 for details
-            return cuda_visible_devices[local_rank]
+            import torch_musa
+            MUSA_AVAILABLE = True
+        except ImportError:
+            pass
+        if MUSA_AVAILABLE:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
+            musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
+            if musa_visible_devices is None:
+                num_device = torch_musa.device_count()
+                musa_visible_devices = list(range(num_device))
+            else:
+                musa_visible_devices = musa_visible_devices.split(',')
+            return int(musa_visible_devices[local_rank])
+        else:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
+            # TODO: return device id of npu and mlu.
+            if not torch.cuda.is_available():
+                return local_rank
+            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+            if cuda_visible_devices is None:
+                num_device = torch.cuda.device_count()
+                cuda_visible_devices = list(range(num_device))
+            else:
+                cuda_visible_devices = cuda_visible_devices.split(',')
+            try:
+                return int(cuda_visible_devices[local_rank])
+            except ValueError:
+                # handle case for Multi-Instance GPUs
+                # see #1148 for details
+                return cuda_visible_devices[local_rank]
 
 
 def _get_host_info() -> str:
diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
index 14c91eb6ca..299cd67557 100644
--- a/mmengine/model/base_model/base_model.py
+++ b/mmengine/model/base_model/base_model.py
@@ -222,6 +222,21 @@ def cuda(
         self._set_device(torch.device(device))
         return super().cuda(device)
 
+    def musa(
+        self,
+        device: Optional[Union[int, str, torch.device]] = None,
+    ) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.musa`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        if device is None or isinstance(device, int):
+            device = torch.device('musa', index=device)
+        self._set_device(torch.device(device))
+        return super().musa(device)
+
     def mlu(
         self,
         device: Union[int, str, torch.device, None] = None,
diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 1f285aca62..af84246874 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -113,6 +113,15 @@ def cuda(self, *args, **kwargs) -> nn.Module:
         self._device = torch.device(torch.cuda.current_device())
         return super().cuda()
 
+    def musa(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+ """ + self._device = torch.device(torch.musa.current_device()) + return super().musa() + def npu(self, *args, **kwargs) -> nn.Module: """Overrides this method to set the :attr:`device` diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 7a82d16603..4f3323f2cc 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -6,7 +6,7 @@ import torch.nn as nn from mmengine.device import (is_cuda_available, is_mlu_available, - is_npu_available) + is_musa_available, is_npu_available) from mmengine.registry import OPTIM_WRAPPERS from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION @@ -74,8 +74,9 @@ def __init__(self, assert digit_version(TORCH_VERSION) >= digit_version('1.6.0'), ( '`torch.cuda.amp` is only available when pytorch version >= 1.6') assert is_cuda_available() or is_npu_available() or is_mlu_available( - ), ('``AmpOptimizerWrapper`` is only available training ' - 'on gpu, npu or mlu') + ) or is_musa_available(), ( + '``AmpOptimizerWrapper`` is only available training ' + 'on gpu, npu, mlu or musa') super().__init__(**kwargs) self._scale_update_param = None diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py index 964518fc90..198babc582 100644 --- a/mmengine/runner/amp.py +++ b/mmengine/runner/amp.py @@ -135,7 +135,13 @@ def autocast(device_type: Optional[str] = None, elif device_type == 'npu': pass - + elif device_type == 'musa': + if dtype is None: + dtype = torch.get_autocast_gpu_dtype() + with torch.musa.amp.autocast( + enabled=enabled, dtype=dtype, cache_enabled=cache_enabled): + yield + return else: # Device like MPS does not support fp16 training or testing. # If an inappropriate device is set and fp16 is enabled, an error diff --git a/mmengine/runner/log_processor.py b/mmengine/runner/log_processor.py index 0453377d0f..98183ae317 100644 --- a/mmengine/runner/log_processor.py +++ b/mmengine/runner/log_processor.py @@ -9,7 +9,8 @@ import numpy as np import torch -from mmengine.device import get_max_cuda_memory, is_cuda_available +from mmengine.device import (get_max_cuda_memory, get_max_musa_memory, + is_cuda_available, is_musa_available) from mmengine.registry import LOG_PROCESSORS @@ -226,11 +227,13 @@ def get_log_after_iter(self, runner, batch_idx: int, log_tag.pop('time') log_tag.pop('data_time') - # If cuda is available, the max memory occupied should be calculated. - if is_cuda_available(): + # If cuda/musa is available, + # the max memory occupied should be calculated. + if is_cuda_available() or is_musa_available(): max_memory = self._get_max_memory(runner) log_str += f'memory: {max_memory} ' tag['memory'] = max_memory + # Loop left keys to fill `log_str`. 
         if mode in ('train', 'val'):
             log_items = []
@@ -498,6 +501,9 @@ def _get_max_memory(self, runner) -> int:
         """
         device = getattr(runner.model, 'output_device', None)
+
+        if is_musa_available():
+            return get_max_musa_memory(device)
         return get_max_cuda_memory(device)
 
     def _get_iter(self, runner, batch_idx: int) -> int:
diff --git a/mmengine/runner/utils.py b/mmengine/runner/utils.py
index d7098c7295..b91025eb07 100644
--- a/mmengine/runner/utils.py
+++ b/mmengine/runner/utils.py
@@ -7,6 +7,7 @@
 import torch
 from torch.utils.data import DataLoader
 
+from mmengine.device import is_cuda_available, is_musa_available
 from mmengine.dist import get_rank, sync_random_seed
 from mmengine.logging import print_log
 from mmengine.utils import digit_version, is_list_of
@@ -69,7 +70,10 @@ def set_random_seed(seed: Optional[int] = None,
     np.random.seed(seed)
     torch.manual_seed(seed)
     # torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    if is_cuda_available():
+        torch.cuda.manual_seed_all(seed)
+    elif is_musa_available():
+        torch.musa.manual_seed_all(seed)
     # os.environ['PYTHONHASHSEED'] = str(seed)
     if deterministic:
         if torch.backends.cudnn.benchmark:
diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py
index 454a224371..53bcd5babf 100644
--- a/mmengine/structures/base_data_element.py
+++ b/mmengine/structures/base_data_element.py
@@ -510,6 +510,17 @@ def cuda(self) -> 'BaseDataElement':
                 new_data.set_data(data)
         return new_data
 
+    # Tensor-like methods
+    def musa(self) -> 'BaseDataElement':
+        """Convert all tensors to musa in data."""
+        new_data = self.new()
+        for k, v in self.items():
+            if isinstance(v, (torch.Tensor, BaseDataElement)):
+                v = v.musa()
+                data = {k: v}
+                new_data.set_data(data)
+        return new_data
+
     # Tensor-like methods
     def npu(self) -> 'BaseDataElement':
         """Convert all tensors to NPU in data."""
diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py
index 8df9727a00..369d445f28 100644
--- a/mmengine/structures/instance_data.py
+++ b/mmengine/structures/instance_data.py
@@ -18,6 +18,9 @@
 elif get_device() == 'mlu':
     BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
+elif get_device() == 'musa':
+    BoolTypeTensor = Union[torch.BoolTensor, torch.musa.BoolTensor]
+    LongTypeTensor = Union[torch.LongTensor, torch.musa.LongTensor]
 else:
     BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
     LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]
diff --git a/mmengine/utils/dl_utils/collect_env.py b/mmengine/utils/dl_utils/collect_env.py
index 6406677a73..61b45ec50e 100644
--- a/mmengine/utils/dl_utils/collect_env.py
+++ b/mmengine/utils/dl_utils/collect_env.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 """This file holding some environment constant for sharing by other files."""
+import os
 import os.path as osp
 import subprocess
 import sys
@@ -9,6 +10,7 @@
 import torch
 
 import mmengine
+from mmengine.device import is_cuda_available, is_musa_available
 from .parrots_wrapper import TORCH_VERSION, get_build_config, is_rocm_pytorch
 
 
@@ -24,6 +26,10 @@ def _get_cuda_home():
     return CUDA_HOME
 
 
+def _get_musa_home():
+    return os.environ.get('MUSA_HOME')
+
+
 def collect_env():
     """Collect the information of the running environments.
 
@@ -51,9 +57,10 @@ def collect_env():
     env_info['sys.platform'] = sys.platform
     env_info['Python'] = sys.version.replace('\n', '')
 
-    cuda_available = torch.cuda.is_available()
+    cuda_available = is_cuda_available()
+    musa_available = is_musa_available()
     env_info['CUDA available'] = cuda_available
-
+    env_info['MUSA available'] = musa_available
     env_info['numpy_random_seed'] = np.random.get_state()[1][0]
 
     if cuda_available:
@@ -89,7 +96,23 @@ def collect_env():
             except subprocess.SubprocessError:
                 nvcc = 'Not Available'
             env_info['NVCC'] = nvcc
+    elif musa_available:
+        devices = defaultdict(list)
+        for k in range(torch.musa.device_count()):
+            devices[torch.musa.get_device_name(k)].append(str(k))
+        for name, device_ids in devices.items():
+            env_info['GPU ' + ','.join(device_ids)] = name
+
+        MUSA_HOME = _get_musa_home()
+        env_info['MUSA_HOME'] = MUSA_HOME
+        if MUSA_HOME is not None and osp.isdir(MUSA_HOME):
+            try:
+                mcc = osp.join(MUSA_HOME, 'bin/mcc')
+                subprocess.check_output(f'"{mcc}" -v', shell=True)
+            except subprocess.SubprocessError:
+                mcc = 'Not Available'
+            env_info['mcc'] = mcc
 
     try:
         # Check C++ Compiler.
         # For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...',
diff --git a/mmengine/utils/dl_utils/time_counter.py b/mmengine/utils/dl_utils/time_counter.py
index 4a1fb42ee0..e4a155dd72 100644
--- a/mmengine/utils/dl_utils/time_counter.py
+++ b/mmengine/utils/dl_utils/time_counter.py
@@ -4,6 +4,7 @@
 
 import torch
 
+from mmengine.device import is_cuda_available, is_musa_available
 from mmengine.dist.utils import master_only
 from mmengine.logging import MMLogger, print_log
 
@@ -84,15 +85,20 @@ def __call__(self, fn):
         def wrapper(*args, **kwargs):
             self.__count += 1
 
-            if self.with_sync and torch.cuda.is_available():
-                torch.cuda.synchronize()
+            if self.with_sync:
+                if is_cuda_available():
+                    torch.cuda.synchronize()
+                elif is_musa_available():
+                    torch.musa.synchronize()
             start_time = time.perf_counter()
 
             result = fn(*args, **kwargs)
 
-            if self.with_sync and torch.cuda.is_available():
-                torch.cuda.synchronize()
-
+            if self.with_sync:
+                if is_cuda_available():
+                    torch.cuda.synchronize()
+                elif is_musa_available():
+                    torch.musa.synchronize()
             elapsed = time.perf_counter() - start_time
             self.print_time(elapsed)
 
diff --git a/tests/test_device/test_device.py b/tests/test_device/test_device.py
index 19bd1f7f19..d2171afa58 100644
--- a/tests/test_device/test_device.py
+++ b/tests/test_device/test_device.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmengine.device import (get_device, is_cuda_available, is_mlu_available,
-                             is_mps_available, is_npu_available)
+                             is_mps_available, is_musa_available,
+                             is_npu_available)
 
 
 def test_get_device():
@@ -13,5 +14,7 @@ def test_get_device():
         assert device == 'mlu'
     elif is_mps_available():
         assert device == 'mps'
+    elif is_musa_available():
+        assert device == 'musa'
     else:
         assert device == 'cpu'
diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index d89f5eb878..a2ef07b713 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -11,6 +11,7 @@
 import torch.distributed as torch_dist
 
 import mmengine.dist as dist
+from mmengine.device import is_musa_available
 from mmengine.dist.dist import sync_random_seed
 from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.utils import digit_version
@@ -117,6 +118,7 @@ def test_all_reduce_params(self):
             self.assertTrue(torch.allclose(item1, item2))
 
 
+@unittest.skipIf(is_musa_available(), reason='musa does not support gloo yet')
 class TestDistWithGLOOBackend(MultiProcessTestCase):
 
     def _init_dist_env(self, rank, world_size):
diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index 4ceebe9088..6dad7ba4f0 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import os.path as osp
+import unittest
 
 import torch
 import torch.nn as nn
 
 from mmengine.config import ConfigDict
+from mmengine.device import is_musa_available
 from mmengine.hooks import EMAHook
 from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.registry import MODELS
@@ -45,6 +47,9 @@ def forward(self, *args, **kwargs):
         return super().forward(*args, **kwargs)
 
 
+# TODO:haowen.han@mthreads.com
+@unittest.skipIf(is_musa_available(),
+                 "musa backend does not support 'aten::lerp.Scalar_out'")
 class TestEMAHook(RunnerTestCase):
 
     def setUp(self) -> None:
diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py
index 4a9ea99752..d30972d360 100644
--- a/tests/test_hooks/test_empty_cache_hook.py
+++ b/tests/test_hooks/test_empty_cache_hook.py
@@ -1,11 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from unittest.mock import patch
 
+import pytest
+
+from mmengine.device import is_cuda_available
 from mmengine.testing import RunnerTestCase
 
 
 class TestEmptyCacheHook(RunnerTestCase):
 
+    @pytest.mark.skipif(
+        not is_cuda_available(), reason='cuda should be available')
     def test_with_runner(self):
         with patch('torch.cuda.empty_cache') as mock_empty_cache:
             cfg = self.epoch_based_cfg
diff --git a/tests/test_runner/test_amp.py b/tests/test_runner/test_amp.py
index a80c7f35cb..7208e25079 100644
--- a/tests/test_runner/test_amp.py
+++ b/tests/test_runner/test_amp.py
@@ -5,7 +5,8 @@
 import torch.nn as nn
 
 import mmengine
-from mmengine.device import get_device, is_mlu_available, is_npu_available
+from mmengine.device import (get_device, is_mlu_available, is_musa_available,
+                             is_npu_available)
 from mmengine.runner import autocast
 from mmengine.utils import digit_version
 from mmengine.utils.dl_utils import TORCH_VERSION
@@ -44,6 +45,21 @@ def test_autocast(self):
                 layer = nn.Conv2d(1, 1, 1).to(device)
                 res = layer(torch.randn(1, 1, 1, 1).to(device))
                 self.assertEqual(res.dtype, torch.float32)
+        elif is_musa_available():
+            device = 'musa'
+            with autocast(device_type=device):
+                # torch.autocast supports musa mode.
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertIn(res.dtype, (torch.bfloat16, torch.float16))
+            with autocast(enabled=False, device_type=device):
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertEqual(res.dtype, torch.float32)
+            # Test with fp32_enabled
+            with autocast(enabled=False, device_type=device):
+                layer = nn.Conv2d(1, 1, 1).to(device)
+                res = layer(torch.randn(1, 1, 1, 1).to(device))
+                self.assertEqual(res.dtype, torch.float32)
         elif not torch.cuda.is_available():
             if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
                 # `torch.cuda.amp.autocast` is only support in gpu mode, if
diff --git a/tests/test_runner/test_log_processor.py b/tests/test_runner/test_log_processor.py
index 9b93a9a8ea..d7fae5722a 100644
--- a/tests/test_runner/test_log_processor.py
+++ b/tests/test_runner/test_log_processor.py
@@ -7,6 +7,7 @@
 import torch
 from parameterized import parameterized
 
+from mmengine.device import is_cuda_available, is_musa_available
 from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
 from mmengine.runner import LogProcessor
 from mmengine.testing import RunnerTestCase
@@ -113,7 +114,7 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
                        f"time: {train_logs['time']:.4f} "
                        f"data_time: {train_logs['data_time']:.4f} ")
 
-        if torch.cuda.is_available():
+        if is_cuda_available() or is_musa_available():
            log_str += 'memory: 100 '
         if mode == 'train':
             log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
@@ -141,7 +142,7 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
                        f"time: {train_logs['time']:.4f} "
                        f"data_time: {train_logs['data_time']:.4f} ")
 
-        if torch.cuda.is_available():
+        if is_cuda_available() or is_musa_available():
             log_str += 'memory: 100 '
 
         if mode == 'train':
@@ -249,6 +250,7 @@ def test_collect_non_scalars(self):
         assert tag['metric1'] is metric1
         assert tag['metric2'] is metric2
 
+    # TODO:haowen.han@mthreads.com MUSA does not support it yet!
     @patch('torch.cuda.max_memory_allocated', MagicMock())
     @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
     def test_get_max_memory(self):
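
Reviewer note (not part of the patch): a minimal usage sketch of the helpers this diff adds. It assumes this branch is installed; `get_device`, `is_musa_available` and `get_max_musa_memory` are the functions introduced above, and the MUSA branch only executes on a machine with a working `torch_musa` build, otherwise the snippet falls back to the existing CPU/CUDA behaviour.

# usage_sketch.py -- hypothetical example, not shipped with mmengine
import torch

from mmengine.device import (get_device, get_max_musa_memory,
                             is_musa_available)

device = get_device()  # 'musa' when torch_musa is importable, else cuda/cpu/...
x = torch.randn(64, 64, device=device)  # allocate something on that backend
y = x @ x  # do a bit of work so some device memory is occupied

if is_musa_available():
    # Peak allocated memory in MB, mirroring get_max_cuda_memory on CUDA.
    print('max MUSA memory (MB):', get_max_musa_memory())
else:
    print(f'ran on {device}; MUSA-specific statistics are unavailable')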