1515
1616import io
1717import sys
18- import time
19- import tracemalloc
2018from unittest import mock
2119
22- import absl .testing .absltest as absltest
20+ from absl .testing import absltest
21+ from absl .testing import parameterized
22+ import tqdm
23+
2324from ai_edge_quantizer .utils import progress_utils
2425
2526
2627class ProgressBarTest (absltest .TestCase ):
2728
28- @mock .patch ('tqdm.tqdm' )
29- def test_progress_bar_update (self , mock_tqdm ):
30- mock_progress_bar_instance = mock_tqdm .return_value
29+ def setUp (self ):
30+ super ().setUp ()
31+ self .mock_tqdm = self .enter_context (
32+ mock .patch .object (tqdm , 'tqdm' , autospec = True , spec_set = True )
33+ )
34+
35+ def test_progress_bar_update (self ):
36+ mock_progress_bar_instance = self .mock_tqdm .return_value
3137 with progress_utils .ProgressBar (total_steps = 10 ) as pb :
3238 pb .update_single_step ()
3339 pb .update_single_step ()
3440
35- mock_tqdm .assert_called_once_with (
41+ self . mock_tqdm .assert_called_once_with (
3642 total = 10 , desc = '' , leave = True , disable = False
3743 )
3844 self .assertEqual (mock_progress_bar_instance .update .call_count , 2 )
3945 mock_progress_bar_instance .update .assert_called_with (1 )
4046 mock_progress_bar_instance .close .assert_called_once ()
4147
42- @mock .patch ('tqdm.tqdm' )
43- def test_progress_bar_disable (self , mock_tqdm ):
44- mock_progress_bar_instance = mock_tqdm .return_value
48+ def test_progress_bar_disable (self ):
49+ mock_progress_bar_instance = self .mock_tqdm .return_value
4550 with progress_utils .ProgressBar (total_steps = 10 , disable = True ):
4651 pass
47- mock_tqdm .assert_called_once_with (
52+ self . mock_tqdm .assert_called_once_with (
4853 total = 10 , desc = '' , leave = True , disable = True
4954 )
5055 mock_progress_bar_instance .close .assert_called_once ()
5156
52- @mock .patch ('tqdm.tqdm' )
53- def test_progress_bar_disappear_on_finish (self , mock_tqdm ):
54- mock_progress_bar_instance = mock_tqdm .return_value
57+ def test_progress_bar_disappear_on_finish (self ):
58+ mock_progress_bar_instance = self .mock_tqdm .return_value
5559 with progress_utils .ProgressBar (total_steps = 10 , disappear_on_finish = True ):
5660 pass
57- mock_tqdm .assert_called_once_with (
61+ self . mock_tqdm .assert_called_once_with (
5862 total = 10 , desc = '' , leave = False , disable = False
5963 )
6064 mock_progress_bar_instance .close .assert_called_once ()
6165
6266
63- class ProgressReportTest (absltest .TestCase ):
67+ class ProgressReportTest (parameterized .TestCase ):
6468
6569 def setUp (self ):
6670 super ().setUp ()
6771 self .mock_time = self .enter_context (
68- mock .patch .object (time , 'time' , autospec = True )
72+ mock .patch .object (progress_utils , 'time' , autospec = True , spec_set = True )
6973 )
70- self .mock_tracemalloc_start = self .enter_context (
71- mock .patch .object (tracemalloc , 'start' , autospec = True )
72- )
73- self .mock_tracemalloc_get_traced_memory = self .enter_context (
74- mock .patch .object (tracemalloc , 'get_traced_memory' , autospec = True )
75- )
76-
77- def test_generate_progress_report (self ):
78- self .mock_time .side_effect = [100.0 , 105.5 ] # Start time, end time.
79- # Mock memory: current=1MB, peak=2MB.
80- self .mock_tracemalloc_get_traced_memory .return_value = (
81- 1 * 1024 * 1024 ,
82- 2 * 1024 * 1024 ,
74+ self .mock_tracemalloc = self .enter_context (
75+ mock .patch .object (
76+ progress_utils , 'tracemalloc' , autospec = True , spec_set = True
77+ )
8378 )
8479
85- progress_report = progress_utils .ProgressReport ()
80+ @parameterized .named_parameters (
81+ ('trace_memory_enabled' , True ), ('trace_memory_disabled' , False )
82+ )
83+ def test_generate_progress_report (self , trace_memory : bool ):
84+ self .mock_time .time .side_effect = [100.0 , 105.5 ] # Start time, end time.
85+
86+ if trace_memory :
87+ self .mock_tracemalloc .is_tracing .return_value = False
88+ self .mock_tracemalloc .start .side_effect = None
89+ self .mock_tracemalloc .stop .return_value = None
90+ self .mock_tracemalloc .get_traced_memory .return_value = (
91+ 1 * 1024 * 1024 ,
92+ 2 * 1024 * 1024 ,
93+ )
94+
95+ progress_report = progress_utils .ProgressReport (trace_memory = trace_memory )
8696 progress_report .capture_progess_start ()
8797
8898 original_model = b'\x01 ' * 2048 # 2KB.
@@ -92,13 +102,17 @@ def test_generate_progress_report(self):
92102 with mock .patch .object (sys , 'stdout' , mock_stdout ):
93103 progress_report .generate_progress_report (original_model , quantized_model )
94104
95- self .mock_tracemalloc_start .assert_called_once ()
96105 output = mock_stdout .getvalue ()
97106 self .assertIn ('Original model size: 2.00 KB' , output )
98107 self .assertIn ('Quantized model size: 1.00 KB' , output )
99108 self .assertIn ('Quantization Ratio: 0.50' , output )
100109 self .assertIn ('Total time: 5.50 seconds' , output )
101- self .assertIn ('Memory peak: 2.00 MB' , output )
110+
111+ if trace_memory :
112+ self .mock_tracemalloc .is_tracing .assert_called_once_with ()
113+ self .mock_tracemalloc .start .assert_called_once_with ()
114+ self .mock_tracemalloc .stop .assert_called_once_with ()
115+ self .assertIn ('Memory peak: 2.00 MB' , output )
102116
103117
104118if __name__ == '__main__' :
0 commit comments