
Commit f7543e2

revise logging/logger.py for ci
1 parent 243d093 commit f7543e2

4 files changed: +32 −39 lines changed


mmengine/hooks/empty_cache_hook.py

Lines changed: 7 additions & 10 deletions
@@ -4,7 +4,7 @@
 import torch
 
 from mmengine.registry import HOOKS
-from ..device import is_cuda_available, is_musa_available
+from ..device import is_musa_available
 from .hook import Hook
 
 DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -50,9 +50,8 @@ def _after_iter(self,
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_iter:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
 
     def _before_epoch(self, runner, mode: str = 'train') -> None:
@@ -63,9 +62,8 @@ def _before_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_before_epoch:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
 
     def _after_epoch(self, runner, mode: str = 'train') -> None:
@@ -76,7 +74,6 @@ def _after_epoch(self, runner, mode: str = 'train') -> None:
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         if self._do_after_epoch:
-            if is_cuda_available():
-                torch.cuda.empty_cache()
-            elif is_musa_available():
+            torch.cuda.empty_cache()
+            if is_musa_available():
                 torch.musa.empty_cache()
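
Note: the hook now calls torch.cuda.empty_cache() without the is_cuda_available() guard. This relies on the fact that, in current PyTorch, torch.cuda.empty_cache() returns without doing anything unless a CUDA context has been initialized, so the unguarded call is expected to be a harmless no-op on CPU-only or MUSA-only machines. A minimal standalone sketch of the resulting pattern (the helper name is illustrative, not part of mmengine):

import torch


def empty_device_caches(musa_available: bool = False) -> None:
    """Release cached allocator memory on whichever backend is present.

    Illustrative helper, not mmengine code: torch.cuda.empty_cache() only
    acts once a CUDA context exists, so no availability check is needed.
    """
    torch.cuda.empty_cache()
    if musa_available:
        # torch.musa is provided by the optional torch_musa extension.
        torch.musa.empty_cache()


empty_device_caches()  # also safe on a machine without any GPU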

mmengine/logging/logger.py

Lines changed: 24 additions & 18 deletions
@@ -14,7 +14,6 @@
 
 from mmengine.utils import ManagerMixin
 from mmengine.utils.manager import _accquire_lock, _release_lock
-from ..device import is_cuda_available, is_musa_available
 
 
 class FilterDuplicateWarning(logging.Filter):
@@ -399,24 +398,17 @@ def _get_device_id():
     except ImportError:
         return 0
     else:
-        local_rank = int(os.getenv('LOCAL_RANK', '0'))
-        if is_cuda_available():
-            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
-            if cuda_visible_devices is None:
-                num_device = torch.cuda.device_count()
-                cuda_visible_devices = list(range(num_device))
-            else:
-                cuda_visible_devices = cuda_visible_devices.split(',')
-            try:
-                return int(cuda_visible_devices[local_rank])
-            except ValueError:
-                # handle case for Multi-Instance GPUs
-                # see #1148 for details
-                return cuda_visible_devices[local_rank]
-        elif is_musa_available():
+        MUSA_AVAILABLE = False
+        try:
+            import torch_musa
+            MUSA_AVAILABLE = True
+        except ImportError:
+            pass
+        if MUSA_AVAILABLE:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
             musa_visible_devices = os.getenv('MUSA_VISIBLE_DEVICES', None)
             if musa_visible_devices is None:
-                num_device = torch.musa.device_count()
+                num_device = torch_musa.device_count()
                 musa_visible_devices = list(range(num_device))
             else:
                 musa_visible_devices = musa_visible_devices.split(',')
@@ -427,8 +419,22 @@ def _get_device_id():
                 # see #1148 for details
                 return musa_visible_devices[local_rank]
         else:
+            local_rank = int(os.getenv('LOCAL_RANK', '0'))
             # TODO: return device id of npu and mlu.
-            return local_rank
+            if not torch.cuda.is_available():
+                return local_rank
+            cuda_visible_devices = os.getenv('CUDA_VISIBLE_DEVICES', None)
+            if cuda_visible_devices is None:
+                num_device = torch.cuda.device_count()
+                cuda_visible_devices = list(range(num_device))
+            else:
+                cuda_visible_devices = cuda_visible_devices.split(',')
+            try:
+                return int(cuda_visible_devices[local_rank])
+            except ValueError:
+                # handle case for Multi-Instance GPUs
+                # see #1148 for details
+                return cuda_visible_devices[local_rank]
 
 
 def _get_host_info() -> str:
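
Note: the module-level import from ..device is replaced by a local torch_musa probe and a plain torch.cuda.is_available() check inside _get_device_id(), so importing the logger module no longer depends on mmengine.device at import time (the commit message cites CI as the motivation). A minimal standalone sketch of the CUDA branch of that resolution logic, assuming only torch plus the LOCAL_RANK and CUDA_VISIBLE_DEVICES environment variables (the helper name is illustrative, not mmengine's):

import os
from typing import Union

import torch


def resolve_cuda_device_id() -> Union[int, str]:
    """Map this process's LOCAL_RANK to a physical CUDA device id.

    Sketch of the CUDA branch of _get_device_id(), not the mmengine
    function itself.
    """
    local_rank = int(os.getenv('LOCAL_RANK', '0'))
    if not torch.cuda.is_available():
        return local_rank
    visible = os.getenv('CUDA_VISIBLE_DEVICES', None)
    if visible is None:
        visible = list(range(torch.cuda.device_count()))
    else:
        visible = visible.split(',')
    try:
        return int(visible[local_rank])
    except ValueError:
        # Multi-Instance GPU ids (e.g. 'MIG-...') are not integers, so
        # return the raw string in that case (see #1148).
        return visible[local_rank]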

tests/test_hooks/test_empty_cache_hook.py

Lines changed: 0 additions & 6 deletions
@@ -1,15 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import unittest
 from unittest.mock import patch
 
-from mmengine.device import is_musa_available
 from mmengine.testing import RunnerTestCase
 
 
-
-@unittest.skipIf(
-    is_musa_available(),
-    'torch_musa do not support torch.musa.reset_peak_memory_stats() yet')
 class TestEmptyCacheHook(RunnerTestCase):
 
     def test_with_runner(self):

tests/test_runner/test_log_processor.py

Lines changed: 1 addition & 5 deletions
@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-import unittest
 from unittest.mock import MagicMock, patch
 
 import numpy as np
@@ -251,10 +250,7 @@ def test_collect_non_scalars(self):
         assert tag['metric1'] is metric1
         assert tag['metric2'] is metric2
 
-
-    @unittest.skipIf(
-        is_musa_available(),
-        'musa backend do not support torch.cuda.reset_peak_memory_stats')
+    # TODO:[email protected] MUSA does not support it yet!
     @patch('torch.cuda.max_memory_allocated', MagicMock())
    @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
     def test_get_max_memory(self):
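
Note: the unittest.skipIf guard can be dropped here because the test already replaces torch.cuda.max_memory_allocated and torch.cuda.reset_peak_memory_stats with MagicMock patches, so it never calls into a real CUDA or MUSA backend. A minimal sketch of that mocking pattern in isolation (test and class names are illustrative, not from the repo):

from unittest import TestCase
from unittest.mock import MagicMock, patch

import torch


class TestMaxMemory(TestCase):

    @patch('torch.cuda.max_memory_allocated', MagicMock(return_value=2 * 1024**2))
    @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
    def test_max_memory_in_mb(self):
        # Both CUDA functions are plain MagicMocks, so this test runs on a
        # machine without any CUDA (or MUSA) device.
        max_mb = round(torch.cuda.max_memory_allocated() / (1024 * 1024))
        self.assertEqual(max_mb, 2)
        torch.cuda.reset_peak_memory_stats()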
