Skip to content

Commit 38864d3

Browse files
gonnetcopybara-github
authored andcommitted
Fix progress_utils_test.py to restrict mocking of system-wide modules to the module under test.
This test was failing because sometimes other modules were calling `time.time` or `tracemalloc.get_traced_memory`, stealing the mocked return value. Also add a warning that running with `tracemalloc` can significantly slow down the code. PiperOrigin-RevId: 895884076
1 parent 2d837dd commit 38864d3

File tree

2 files changed

+83
-42
lines changed

2 files changed

+83
-42
lines changed

ai_edge_quantizer/utils/progress_utils.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
"""Utility functions to display a progress bar and progress report."""
1717

18+
import logging
1819
import time
1920
import tracemalloc
2021
import tqdm
@@ -55,42 +56,68 @@ def close(self):
5556

5657

5758
class ProgressReport:
58-
"""A class to generate a progress report for the quantization process."""
59+
"""A class to generate a progress report for the quantization process.
60+
61+
If initialized with `trace_memory=True`, it will also track the peak memory
62+
use using `tracemalloc`, which may hurt performance and interfere with other
63+
code using `tracemalloc` concurrently.
64+
"""
65+
_description: str
66+
_trace_memory: bool
67+
_start_time: float | None = None
68+
_tracemalloc_started_by_me: bool = False
5969

60-
def __init__(self, description: str = ''):
70+
def __init__(self, description: str = '', trace_memory: bool = False):
6171
self._description = description
62-
self._start_time = None
72+
self._trace_memory = trace_memory
6373

6474
def capture_progess_start(self):
6575
self._start_time = time.time()
66-
tracemalloc.start()
76+
if self._trace_memory:
77+
logging.warning(
78+
'Progress bar reporting with `trace_memory=True` is enabled which may'
79+
' significantly slow down your computations!'
80+
)
81+
self._tracemalloc_started_by_me = not tracemalloc.is_tracing()
82+
if self._tracemalloc_started_by_me:
83+
tracemalloc.start()
84+
else:
85+
tracemalloc.reset_peak()
86+
87+
def _capture_progress_end(self) -> int | None:
88+
if self._trace_memory:
89+
_, mem_peak_bytes = tracemalloc.get_traced_memory()
90+
if self._tracemalloc_started_by_me:
91+
tracemalloc.stop()
92+
self._tracemalloc_started_by_me = False
93+
return mem_peak_bytes
6794

6895
def render_report(
6996
self,
7097
original_size: int,
7198
quantized_size: int,
7299
quantization_ratio: float,
73-
memory_peak: float,
74100
total_time: float,
101+
memory_peak: float | None,
75102
):
76103
"""Prints out the progress report."""
77104
print(f'Original model size: {original_size/1024:.2f} KB')
78105
print(f'Quantized model size: {quantized_size/1024:.2f} KB')
79106
print(f'Quantization Ratio: {quantization_ratio:.2f}')
80107
print(f'Total time: {total_time:.2f} seconds')
81-
print(f'Memory peak: {memory_peak:.2f} MB')
108+
if memory_peak is not None:
109+
print(f'Memory peak: {memory_peak/1024/1024:.2f} MB')
82110

83111
def generate_progress_report(self, original_model, quantized_model):
84112
original_size = len(original_model)
85113
quantized_size = len(quantized_model)
86114
quantization_ratio = quantized_size / original_size
87115
total_time = time.time() - self._start_time
88-
_, mem_peak_bytes = tracemalloc.get_traced_memory()
89-
mem_peak_mb = mem_peak_bytes / 1024 / 1024
116+
mem_peak_bytes = self._capture_progress_end()
90117
self.render_report(
91118
original_size,
92119
quantized_size,
93120
quantization_ratio,
94-
mem_peak_mb,
95121
total_time,
122+
mem_peak_bytes,
96123
)

ai_edge_quantizer/utils/progress_utils_test.py

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,74 +15,84 @@
1515

1616
import io
1717
import sys
18-
import time
19-
import tracemalloc
2018
from unittest import mock
2119

22-
import absl.testing.absltest as absltest
20+
from absl.testing import absltest
21+
from absl.testing import parameterized
22+
import tqdm
23+
2324
from ai_edge_quantizer.utils import progress_utils
2425

2526

2627
class ProgressBarTest(absltest.TestCase):
2728

28-
@mock.patch('tqdm.tqdm')
29-
def test_progress_bar_update(self, mock_tqdm):
30-
mock_progress_bar_instance = mock_tqdm.return_value
29+
def setUp(self):
30+
super().setUp()
31+
self.mock_tqdm = self.enter_context(
32+
mock.patch.object(tqdm, 'tqdm', autospec=True, spec_set=True)
33+
)
34+
35+
def test_progress_bar_update(self):
36+
mock_progress_bar_instance = self.mock_tqdm.return_value
3137
with progress_utils.ProgressBar(total_steps=10) as pb:
3238
pb.update_single_step()
3339
pb.update_single_step()
3440

35-
mock_tqdm.assert_called_once_with(
41+
self.mock_tqdm.assert_called_once_with(
3642
total=10, desc='', leave=True, disable=False
3743
)
3844
self.assertEqual(mock_progress_bar_instance.update.call_count, 2)
3945
mock_progress_bar_instance.update.assert_called_with(1)
4046
mock_progress_bar_instance.close.assert_called_once()
4147

42-
@mock.patch('tqdm.tqdm')
43-
def test_progress_bar_disable(self, mock_tqdm):
44-
mock_progress_bar_instance = mock_tqdm.return_value
48+
def test_progress_bar_disable(self):
49+
mock_progress_bar_instance = self.mock_tqdm.return_value
4550
with progress_utils.ProgressBar(total_steps=10, disable=True):
4651
pass
47-
mock_tqdm.assert_called_once_with(
52+
self.mock_tqdm.assert_called_once_with(
4853
total=10, desc='', leave=True, disable=True
4954
)
5055
mock_progress_bar_instance.close.assert_called_once()
5156

52-
@mock.patch('tqdm.tqdm')
53-
def test_progress_bar_disappear_on_finish(self, mock_tqdm):
54-
mock_progress_bar_instance = mock_tqdm.return_value
57+
def test_progress_bar_disappear_on_finish(self):
58+
mock_progress_bar_instance = self.mock_tqdm.return_value
5559
with progress_utils.ProgressBar(total_steps=10, disappear_on_finish=True):
5660
pass
57-
mock_tqdm.assert_called_once_with(
61+
self.mock_tqdm.assert_called_once_with(
5862
total=10, desc='', leave=False, disable=False
5963
)
6064
mock_progress_bar_instance.close.assert_called_once()
6165

6266

63-
class ProgressReportTest(absltest.TestCase):
67+
class ProgressReportTest(parameterized.TestCase):
6468

6569
def setUp(self):
6670
super().setUp()
6771
self.mock_time = self.enter_context(
68-
mock.patch.object(time, 'time', autospec=True)
72+
mock.patch.object(progress_utils, 'time', autospec=True, spec_set=True)
6973
)
70-
self.mock_tracemalloc_start = self.enter_context(
71-
mock.patch.object(tracemalloc, 'start', autospec=True)
72-
)
73-
self.mock_tracemalloc_get_traced_memory = self.enter_context(
74-
mock.patch.object(tracemalloc, 'get_traced_memory', autospec=True)
75-
)
76-
77-
def test_generate_progress_report(self):
78-
self.mock_time.side_effect = [100.0, 105.5] # Start time, end time.
79-
# Mock memory: current=1MB, peak=2MB.
80-
self.mock_tracemalloc_get_traced_memory.return_value = (
81-
1 * 1024 * 1024,
82-
2 * 1024 * 1024,
74+
self.mock_tracemalloc = self.enter_context(
75+
mock.patch.object(
76+
progress_utils, 'tracemalloc', autospec=True, spec_set=True
77+
)
8378
)
8479

85-
progress_report = progress_utils.ProgressReport()
80+
@parameterized.named_parameters(
81+
('trace_memory_enabled', True), ('trace_memory_disabled', False)
82+
)
83+
def test_generate_progress_report(self, trace_memory: bool):
84+
self.mock_time.time.side_effect = [100.0, 105.5] # Start time, end time.
85+
86+
if trace_memory:
87+
self.mock_tracemalloc.is_tracing.return_value = False
88+
self.mock_tracemalloc.start.side_effect = None
89+
self.mock_tracemalloc.stop.return_value = None
90+
self.mock_tracemalloc.get_traced_memory.return_value = (
91+
1 * 1024 * 1024,
92+
2 * 1024 * 1024,
93+
)
94+
95+
progress_report = progress_utils.ProgressReport(trace_memory=trace_memory)
8696
progress_report.capture_progess_start()
8797

8898
original_model = b'\x01' * 2048 # 2KB.
@@ -92,13 +102,17 @@ def test_generate_progress_report(self):
92102
with mock.patch.object(sys, 'stdout', mock_stdout):
93103
progress_report.generate_progress_report(original_model, quantized_model)
94104

95-
self.mock_tracemalloc_start.assert_called_once()
96105
output = mock_stdout.getvalue()
97106
self.assertIn('Original model size: 2.00 KB', output)
98107
self.assertIn('Quantized model size: 1.00 KB', output)
99108
self.assertIn('Quantization Ratio: 0.50', output)
100109
self.assertIn('Total time: 5.50 seconds', output)
101-
self.assertIn('Memory peak: 2.00 MB', output)
110+
111+
if trace_memory:
112+
self.mock_tracemalloc.is_tracing.assert_called_once_with()
113+
self.mock_tracemalloc.start.assert_called_once_with()
114+
self.mock_tracemalloc.stop.assert_called_once_with()
115+
self.assertIn('Memory peak: 2.00 MB', output)
102116

103117

104118
if __name__ == '__main__':

0 commit comments

Comments
 (0)