33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ import bisect
7+ from collections import OrderedDict
68import time
9+ from typing import Dict , Optional
710
811
9- class AverageMeter (object ):
10- """Computes and stores the average and current value"""
12+ class Meter (object ):
13+ """Base class for Meters."""
14+
1115 def __init__ (self ):
16+ pass
17+
18+ def state_dict (self ):
19+ return {}
20+
21+ def load_state_dict (self , state_dict ):
22+ pass
23+
24+ def reset (self ):
25+ raise NotImplementedError
26+
27+ @property
28+ def smoothed_value (self ) -> float :
29+ """Smoothed value used for logging."""
30+ raise NotImplementedError
31+
32+
33+ class AverageMeter (Meter ):
34+ """Computes and stores the average and current value"""
35+
36+ def __init__ (self , round : Optional [int ] = None ):
37+ self .round = round
1238 self .reset ()
1339
1440 def reset (self ):
15- self .val = 0
16- self .avg = 0
17- self .sum = 0
18- self .count = 0
41+ self .val = None # most recent update
42+ self .sum = 0 # sum from all updates
43+ self .count = 0 # total n from all updates
1944
2045 def update (self , val , n = 1 ):
21- self .val = val
22- self .sum += val * n
23- self .count += n
24- self .avg = self .sum / self .count
46+ if val is not None :
47+ self .val = val
48+ if n > 0 :
49+ self .sum += val * n
50+ self .count += n
51+
52+ def state_dict (self ):
53+ return {
54+ 'val' : self .val ,
55+ 'sum' : self .sum ,
56+ 'count' : self .count ,
57+ 'round' : self .round ,
58+ }
59+
60+ def load_state_dict (self , state_dict ):
61+ self .val = state_dict ['val' ]
62+ self .sum = state_dict ['sum' ]
63+ self .count = state_dict ['count' ]
64+ self .round = state_dict .get ('round' , None )
65+
66+ @property
67+ def avg (self ):
68+ return self .sum / self .count if self .count > 0 else self .val
69+
70+ @property
71+ def smoothed_value (self ) -> float :
72+ val = self .avg
73+ if self .round is not None and val is not None :
74+ val = round (val , self .round )
75+ return val
2576
2677
27- class TimeMeter (object ):
78+ class TimeMeter (Meter ):
2879 """Computes the average occurrence of some event per second"""
29- def __init__ (self , init = 0 ):
30- self .reset (init )
3180
32- def reset (self , init = 0 ):
81+ def __init__ (self , init : int = 0 , n : int = 0 , round : Optional [int ] = None ):
82+ self .round = round
83+ self .reset (init , n )
84+
85+ def reset (self , init = 0 , n = 0 ):
3386 self .init = init
3487 self .start = time .time ()
35- self .n = 0
88+ self .n = n
3689
3790 def update (self , val = 1 ):
3891 self .n += val
3992
93+ def state_dict (self ):
94+ return {
95+ 'init' : self .elapsed_time ,
96+ 'n' : self .n ,
97+ 'round' : self .round ,
98+ }
99+
100+ def load_state_dict (self , state_dict ):
101+ if 'start' in state_dict :
102+ # backwards compatibility for old state_dicts
103+ self .reset (init = state_dict ['init' ])
104+ else :
105+ self .reset (init = state_dict ['init' ], n = state_dict ['n' ])
106+ self .round = state_dict .get ('round' , None )
107+
40108 @property
41109 def avg (self ):
42110 return self .n / self .elapsed_time
@@ -45,11 +113,22 @@ def avg(self):
45113 def elapsed_time (self ):
46114 return self .init + (time .time () - self .start )
47115
116+ @property
117+ def smoothed_value (self ) -> float :
118+ val = self .avg
119+ if self .round is not None and val is not None :
120+ val = round (val , self .round )
121+ return val
122+
48123
49- class StopwatchMeter (object ):
124+ class StopwatchMeter (Meter ):
50125 """Computes the sum/avg duration of some event in seconds"""
51- def __init__ (self ):
52- self .reset ()
126+
127+ def __init__ (self , round : Optional [int ] = None ):
128+ self .round = round
129+ self .sum = 0
130+ self .n = 0
131+ self .start_time = None
53132
54133 def start (self ):
55134 self .start_time = time .time ()
@@ -59,13 +138,98 @@ def stop(self, n=1):
59138 delta = time .time () - self .start_time
60139 self .sum += delta
61140 self .n += n
62- self .start_time = None
63141
64142 def reset (self ):
65- self .sum = 0
66- self .n = 0
143+ self .sum = 0 # cumulative time during which stopwatch was active
144+ self .n = 0 # total n across all start/stop
145+ self .start ()
146+
147+ def state_dict (self ):
148+ return {
149+ 'sum' : self .sum ,
150+ 'n' : self .n ,
151+ 'round' : self .round ,
152+ }
153+
154+ def load_state_dict (self , state_dict ):
155+ self .sum = state_dict ['sum' ]
156+ self .n = state_dict ['n' ]
67157 self .start_time = None
158+ self .round = state_dict .get ('round' , None )
68159
69160 @property
70161 def avg (self ):
71- return self .sum / self .n
162+ return self .sum / self .n if self .n > 0 else self .sum
163+
164+ @property
165+ def elapsed_time (self ):
166+ if self .start_time is None :
167+ return 0.
168+ return time .time () - self .start_time
169+
170+ @property
171+ def smoothed_value (self ) -> float :
172+ val = self .avg if self .sum > 0 else self .elapsed_time
173+ if self .round is not None and val is not None :
174+ val = round (val , self .round )
175+ return val
176+
177+
178+ class MetersDict (OrderedDict ):
179+ """A sorted dictionary of :class:`Meters`.
180+
181+ Meters are sorted according to a priority that is given when the
182+ meter is first added to the dictionary.
183+ """
184+
185+ def __init__ (self , * args , ** kwargs ):
186+ super ().__init__ (* args , ** kwargs )
187+ self .priorities = []
188+
189+ def __setitem__ (self , key , value ):
190+ assert key not in self , "MetersDict doesn't support reassignment"
191+ priority , value = value
192+ bisect .insort (self .priorities , (priority , len (self .priorities ), key ))
193+ super ().__setitem__ (key , value )
194+ for _ , _ , key in self .priorities : # reorder dict to match priorities
195+ self .move_to_end (key )
196+
197+ def add_meter (self , key , meter , priority ):
198+ self .__setitem__ (key , (priority , meter ))
199+
200+ def state_dict (self ):
201+ return [
202+ (pri , key , self [key ].__class__ .__name__ , self [key ].state_dict ())
203+ for pri , _ , key in self .priorities
204+ # can't serialize DerivedMeter instances
205+ if not isinstance (self [key ], MetersDict ._DerivedMeter )
206+ ]
207+
208+ def load_state_dict (self , state_dict ):
209+ self .clear ()
210+ self .priorities .clear ()
211+ for pri , key , meter_cls , meter_state in state_dict :
212+ meter = globals ()[meter_cls ]()
213+ meter .load_state_dict (meter_state )
214+ self .add_meter (key , meter , pri )
215+
216+ def get_smoothed_value (self , key : str ) -> float :
217+ """Get a single smoothed value."""
218+ meter = self [key ]
219+ if isinstance (meter , MetersDict ._DerivedMeter ):
220+ return meter .fn (self )
221+ else :
222+ return meter .smoothed_value
223+
224+ def get_smoothed_values (self ) -> Dict [str , float ]:
225+ """Get all smoothed values."""
226+ return OrderedDict ([(key , self .get_smoothed_value (key )) for key in self .keys ()])
227+
228+ class _DerivedMeter (Meter ):
229+ """A Meter whose values are derived from other Meters."""
230+
231+ def __init__ (self , fn ):
232+ self .fn = fn
233+
234+ def reset (self ):
235+ pass
0 commit comments