 from parlai.core.thread_utils import SharedTable
 from parlai.core.utils import round_sigfigs, no_lock
 from collections import Counter
+from parlai.core.utils import warn_once
 
 import re
@@ -23,6 +24,14 @@
     # We'll just turn off things, but we might want to warn the user
     nltkbleu = None
 
+try:
+    import rouge
+except ImportError:
+    # User doesn't have rouge installed, so we can't use it for rouge
+    # We'll just turn off things, but we might want to warn the user
+    warn_once('Rouge metrics require py-rouge. Please run `pip install py-rouge`.')
+    rouge = None
+
 re_art = re.compile(r'\b(a|an|the)\b')
 re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
 
@@ -103,6 +112,29 @@ def _bleu(guess, answers):
     )
 
 
+def _rouge(guess, answers):
+    """Compute ROUGE score between guess and *any* answer. Return the best."""
+    global rouge
+    if rouge is None:
+        return None, None, None
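+    # 'rouge-n' with max_n=2 yields ROUGE-1 and ROUGE-2; 'rouge-l' adds ROUGE-L.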
+    evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2)
+    try:
+        scores = [evaluator.get_scores(normalize_answer(guess), normalize_answer(a))
+                  for a in answers]
+    except LookupError:
+        warn_once(
+            'ROUGE requires nltk punkt tokenizer. Please run '
+            '`python -c "import nltk; nltk.download(\'punkt\')"`'
+        )
+        rouge = None
+        return None, None, None
+
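+    # py-rouge reports precision/recall/F ('p'/'r'/'f'); keep the recall score
+    # and take the best match over the reference answers.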
+    scores_rouge1 = [score['rouge-1']['r'] for score in scores]
+    scores_rouge2 = [score['rouge-2']['r'] for score in scores]
+    scores_rougel = [score['rouge-l']['r'] for score in scores]
+    return max(scores_rouge1), max(scores_rouge2), max(scores_rougel)
+
+
 def aggregate_metrics(reporters):
     """Aggregate metrics from multiple reports."""
     # reporters is a list of teachers or worlds
@@ -111,6 +143,10 @@ def aggregate_metrics(reporters):
     sums = {'accuracy': 0, 'f1': 0, 'loss': 0, 'ppl': 0}
     if nltkbleu is not None:
         sums['bleu'] = 0
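+    # Aggregate ROUGE only when py-rouge is available (mirrors the BLEU handling).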
+    if rouge is not None:
+        sums['rouge-1'] = 0.0
+        sums['rouge-2'] = 0.0
+        sums['rouge-L'] = 0.0
     num_tasks = 0
     total = 0
     for i in range(len(reporters)):
@@ -146,6 +182,11 @@ def __init__(self, opt):
         if nltkbleu is not None:
             # only compute bleu if we can
             self.metrics_list.append('bleu')
+        if rouge is not None:
+            # only compute rouge if we can
+            self.metrics_list.append('rouge-1')
+            self.metrics_list.append('rouge-2')
+            self.metrics_list.append('rouge-L')
         for k in self.metrics_list:
             self.metrics[k] = 0.0
             self.metrics[k + '_cnt'] = 0
@@ -219,20 +260,30 @@ def update(self, observation, labels):
             # F1 and BLEU metrics.
             f1 = _f1_score(prediction, labels)
             bleu = _bleu(prediction, labels)
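+            # Best ROUGE-1/2/L recall over the labels; None when py-rouge or the
+            # punkt tokenizer is missing.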
+            rouge1, rouge2, rougel = _rouge(prediction, labels)
+
             with self._lock():
                 self.metrics['f1'] += f1
                 self.metrics['f1_cnt'] += 1
                 if bleu is not None:
                     self.metrics['bleu'] += bleu
                     self.metrics['bleu_cnt'] += 1
+                if rouge1 is not None:
+                    self.metrics['rouge-1'] += rouge1
+                    self.metrics['rouge-2'] += rouge2
+                    self.metrics['rouge-L'] += rougel
+                    self.metrics['rouge-1_cnt'] += 1
+                    self.metrics['rouge-2_cnt'] += 1
+                    self.metrics['rouge-L_cnt'] += 1
 
         # Ranking metrics.
         self._update_ranking_metrics(observation, labels)
 
         # User-reported metrics
         if 'metrics' in observation:
             for k, v in observation['metrics'].items():
-                if k not in ['correct', 'f1', 'hits@k', 'bleu']:
+                if k not in ['correct', 'f1', 'hits@k', 'bleu', 'rouge-1',
+                             'rouge-2', 'rouge-L']:
                     if k in self.metrics_list:
                         with self._lock():
                             self.metrics[k] += v