Skip to content
This repository was archived by the owner on Mar 20, 2026. It is now read-only.

Commit 1e324a5

Browse files
Myle Ottfacebook-github-bot
authored andcommitted
Add metrics.py
Summary: Pull Request resolved: fairinternal/fairseq-py#973 Differential Revision: D19266024 Pulled By: myleott fbshipit-source-id: 3e2f1b8d10a5ac5ee23183a34f2078ba521c905d
1 parent 9090dad commit 1e324a5

2 files changed

Lines changed: 402 additions & 22 deletions

File tree

fairseq/meters.py

Lines changed: 186 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,108 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import bisect
7+
from collections import OrderedDict
68
import time
9+
from typing import Dict, Optional
710

811

9-
class AverageMeter(object):
10-
"""Computes and stores the average and current value"""
12+
class Meter(object):
13+
"""Base class for Meters."""
14+
1115
def __init__(self):
16+
pass
17+
18+
def state_dict(self):
19+
return {}
20+
21+
def load_state_dict(self, state_dict):
22+
pass
23+
24+
def reset(self):
25+
raise NotImplementedError
26+
27+
@property
28+
def smoothed_value(self) -> float:
29+
"""Smoothed value used for logging."""
30+
raise NotImplementedError
31+
32+
33+
class AverageMeter(Meter):
34+
"""Computes and stores the average and current value"""
35+
36+
def __init__(self, round: Optional[int] = None):
37+
self.round = round
1238
self.reset()
1339

1440
def reset(self):
15-
self.val = 0
16-
self.avg = 0
17-
self.sum = 0
18-
self.count = 0
41+
self.val = None # most recent update
42+
self.sum = 0 # sum from all updates
43+
self.count = 0 # total n from all updates
1944

2045
def update(self, val, n=1):
21-
self.val = val
22-
self.sum += val * n
23-
self.count += n
24-
self.avg = self.sum / self.count
46+
if val is not None:
47+
self.val = val
48+
if n > 0:
49+
self.sum += val * n
50+
self.count += n
51+
52+
def state_dict(self):
53+
return {
54+
'val': self.val,
55+
'sum': self.sum,
56+
'count': self.count,
57+
'round': self.round,
58+
}
59+
60+
def load_state_dict(self, state_dict):
61+
self.val = state_dict['val']
62+
self.sum = state_dict['sum']
63+
self.count = state_dict['count']
64+
self.round = state_dict.get('round', None)
65+
66+
@property
67+
def avg(self):
68+
return self.sum / self.count if self.count > 0 else self.val
69+
70+
@property
71+
def smoothed_value(self) -> float:
72+
val = self.avg
73+
if self.round is not None and val is not None:
74+
val = round(val, self.round)
75+
return val
2576

2677

27-
class TimeMeter(object):
78+
class TimeMeter(Meter):
2879
"""Computes the average occurrence of some event per second"""
29-
def __init__(self, init=0):
30-
self.reset(init)
3180

32-
def reset(self, init=0):
81+
def __init__(self, init: int = 0, n: int = 0, round: Optional[int] = None):
82+
self.round = round
83+
self.reset(init, n)
84+
85+
def reset(self, init=0, n=0):
3386
self.init = init
3487
self.start = time.time()
35-
self.n = 0
88+
self.n = n
3689

3790
def update(self, val=1):
3891
self.n += val
3992

93+
def state_dict(self):
94+
return {
95+
'init': self.elapsed_time,
96+
'n': self.n,
97+
'round': self.round,
98+
}
99+
100+
def load_state_dict(self, state_dict):
101+
if 'start' in state_dict:
102+
# backwards compatibility for old state_dicts
103+
self.reset(init=state_dict['init'])
104+
else:
105+
self.reset(init=state_dict['init'], n=state_dict['n'])
106+
self.round = state_dict.get('round', None)
107+
40108
@property
41109
def avg(self):
42110
return self.n / self.elapsed_time
@@ -45,11 +113,22 @@ def avg(self):
45113
def elapsed_time(self):
46114
return self.init + (time.time() - self.start)
47115

116+
@property
117+
def smoothed_value(self) -> float:
118+
val = self.avg
119+
if self.round is not None and val is not None:
120+
val = round(val, self.round)
121+
return val
122+
48123

49-
class StopwatchMeter(object):
124+
class StopwatchMeter(Meter):
50125
"""Computes the sum/avg duration of some event in seconds"""
51-
def __init__(self):
52-
self.reset()
126+
127+
def __init__(self, round: Optional[int] = None):
128+
self.round = round
129+
self.sum = 0
130+
self.n = 0
131+
self.start_time = None
53132

54133
def start(self):
55134
self.start_time = time.time()
@@ -59,13 +138,98 @@ def stop(self, n=1):
59138
delta = time.time() - self.start_time
60139
self.sum += delta
61140
self.n += n
62-
self.start_time = None
63141

64142
def reset(self):
65-
self.sum = 0
66-
self.n = 0
143+
self.sum = 0 # cumulative time during which stopwatch was active
144+
self.n = 0 # total n across all start/stop
145+
self.start()
146+
147+
def state_dict(self):
148+
return {
149+
'sum': self.sum,
150+
'n': self.n,
151+
'round': self.round,
152+
}
153+
154+
def load_state_dict(self, state_dict):
155+
self.sum = state_dict['sum']
156+
self.n = state_dict['n']
67157
self.start_time = None
158+
self.round = state_dict.get('round', None)
68159

69160
@property
70161
def avg(self):
71-
return self.sum / self.n
162+
return self.sum / self.n if self.n > 0 else self.sum
163+
164+
@property
165+
def elapsed_time(self):
166+
if self.start_time is None:
167+
return 0.
168+
return time.time() - self.start_time
169+
170+
@property
171+
def smoothed_value(self) -> float:
172+
val = self.avg if self.sum > 0 else self.elapsed_time
173+
if self.round is not None and val is not None:
174+
val = round(val, self.round)
175+
return val
176+
177+
178+
class MetersDict(OrderedDict):
179+
"""A sorted dictionary of :class:`Meters`.
180+
181+
Meters are sorted according to a priority that is given when the
182+
meter is first added to the dictionary.
183+
"""
184+
185+
def __init__(self, *args, **kwargs):
186+
super().__init__(*args, **kwargs)
187+
self.priorities = []
188+
189+
def __setitem__(self, key, value):
190+
assert key not in self, "MetersDict doesn't support reassignment"
191+
priority, value = value
192+
bisect.insort(self.priorities, (priority, len(self.priorities), key))
193+
super().__setitem__(key, value)
194+
for _, _, key in self.priorities: # reorder dict to match priorities
195+
self.move_to_end(key)
196+
197+
def add_meter(self, key, meter, priority):
198+
self.__setitem__(key, (priority, meter))
199+
200+
def state_dict(self):
201+
return [
202+
(pri, key, self[key].__class__.__name__, self[key].state_dict())
203+
for pri, _, key in self.priorities
204+
# can't serialize DerivedMeter instances
205+
if not isinstance(self[key], MetersDict._DerivedMeter)
206+
]
207+
208+
def load_state_dict(self, state_dict):
209+
self.clear()
210+
self.priorities.clear()
211+
for pri, key, meter_cls, meter_state in state_dict:
212+
meter = globals()[meter_cls]()
213+
meter.load_state_dict(meter_state)
214+
self.add_meter(key, meter, pri)
215+
216+
def get_smoothed_value(self, key: str) -> float:
217+
"""Get a single smoothed value."""
218+
meter = self[key]
219+
if isinstance(meter, MetersDict._DerivedMeter):
220+
return meter.fn(self)
221+
else:
222+
return meter.smoothed_value
223+
224+
def get_smoothed_values(self) -> Dict[str, float]:
225+
"""Get all smoothed values."""
226+
return OrderedDict([(key, self.get_smoothed_value(key)) for key in self.keys()])
227+
228+
class _DerivedMeter(Meter):
229+
"""A Meter whose values are derived from other Meters."""
230+
231+
def __init__(self, fn):
232+
self.fn = fn
233+
234+
def reset(self):
235+
pass

0 commit comments

Comments
 (0)