
Commit 471db18

add rouge metrics (#1719)

* Add ROUGE metrics by using py-rouge.
* Warn if the nltk punkt tokenizer is not installed.
* Add the tokenizer download to our CircleCI config so we can test ROUGE in CI.

1 parent e0cb245 · commit 471db18

File tree: 4 files changed, +62 -6 lines

* .circleci/config.yml
* parlai/core/metrics.py
* requirements.txt
* tests/test_eval_model.py
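
With this change, ROUGE reporting has two runtime prerequisites, both visible in the diffs below: the py-rouge package (added to requirements.txt) and nltk's punkt tokenizer data (downloaded in the CircleCI step). A minimal local setup mirroring those two changes:

import nltk

nltk.download('punkt')  # tokenizer data py-rouge needs at scoring time
                        # (see the LookupError handling in metrics.py below)
import rouge            # provided by the py-rouge package: pip install py-rouge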

.circleci/config.yml

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@ installdeps: &installdeps
     name: Installs basic dependencies
     command: |
       python setup.py develop
+      python -c "import nltk; nltk.download('punkt')"
 
 installtorchgpu: &installtorchgpu
   run:
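
The CI step above only downloads the tokenizer data. As a local sanity check (not part of this commit, shown for convenience), punkt availability can be verified before running evaluation:

import nltk

try:
    # punkt is stored under 'tokenizers/punkt' in the nltk data path
    nltk.data.find('tokenizers/punkt')
    print('punkt tokenizer is available')
except LookupError:
    # mirrors the CI command above
    nltk.download('punkt')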

parlai/core/metrics.py

Lines changed: 52 additions & 1 deletion

@@ -13,6 +13,7 @@
 from parlai.core.thread_utils import SharedTable
 from parlai.core.utils import round_sigfigs, no_lock
 from collections import Counter
+from parlai.core.utils import warn_once
 
 import re
 
@@ -23,6 +24,14 @@
     # We'll just turn off things, but we might want to warn the user
     nltkbleu = None
 
+try:
+    import rouge as rouge
+except ImportError:
+    # User doesn't have rouge installed, so we can't use it for rouge
+    # We'll just turn off things, but we might want to warn the user
+    warn_once('Rouge metrics require py-rouge. Please run `pip install py-rouge`.')
+    rouge = None
+
 re_art = re.compile(r'\b(a|an|the)\b')
 re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
 
@@ -103,6 +112,29 @@ def _bleu(guess, answers):
     )
 
 
+def _rouge(guess, answers):
+    global rouge
+    """Compute ROUGE score between guess and *any* answers. Return the best."""
+    if rouge is None:
+        return None, None, None
+    evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2)
+    try:
+        scores = [evaluator.get_scores(normalize_answer(guess), normalize_answer(a))
+                  for a in answers]
+    except LookupError:
+        warn_once(
+            'ROUGE requires nltk punkt tokenizer. Please run '
+            '`python -c "import nltk; nltk.download(\'punkt\')`'
+        )
+        rouge = None
+        return None, None, None
+
+    scores_rouge1 = [score['rouge-1']['r'] for score in scores]
+    scores_rouge2 = [score['rouge-2']['r'] for score in scores]
+    scores_rougel = [score['rouge-l']['r'] for score in scores]
+    return max(scores_rouge1), max(scores_rouge2), max(scores_rougel)
+
+
 def aggregate_metrics(reporters):
     """Aggregate metrics from multiple reports."""
     # reporters is a list of teachers or worlds
@@ -111,6 +143,10 @@ def aggregate_metrics(reporters):
     sums = {'accuracy': 0, 'f1': 0, 'loss': 0, 'ppl': 0}
     if nltkbleu is not None:
         sums['bleu'] = 0
+    if rouge is not None:
+        sums['rouge-1'] = 0.0
+        sums['rouge-2'] = 0.0
+        sums['rouge-L'] = 0.0
     num_tasks = 0
     total = 0
     for i in range(len(reporters)):
@@ -146,6 +182,11 @@ def __init__(self, opt):
         if nltkbleu is not None:
             # only compute bleu if we can
             self.metrics_list.append('bleu')
+        if rouge is not None:
+            # only compute rouge if we can
+            self.metrics_list.append('rouge-1')
+            self.metrics_list.append('rouge-2')
+            self.metrics_list.append('rouge-L')
         for k in self.metrics_list:
             self.metrics[k] = 0.0
             self.metrics[k + '_cnt'] = 0
@@ -219,20 +260,30 @@ def update(self, observation, labels):
             # F1 and BLEU metrics.
             f1 = _f1_score(prediction, labels)
             bleu = _bleu(prediction, labels)
+            rouge1, rouge2, rougel = _rouge(prediction, labels)
+
             with self._lock():
                 self.metrics['f1'] += f1
                 self.metrics['f1_cnt'] += 1
                 if bleu is not None:
                     self.metrics['bleu'] += bleu
                     self.metrics['bleu_cnt'] += 1
+                if rouge1 is not None:
+                    self.metrics['rouge-1'] += rouge1
+                    self.metrics['rouge-2'] += rouge2
+                    self.metrics['rouge-L'] += rougel
+                    self.metrics['rouge-1_cnt'] += 1
+                    self.metrics['rouge-2_cnt'] += 1
+                    self.metrics['rouge-L_cnt'] += 1
 
             # Ranking metrics.
             self._update_ranking_metrics(observation, labels)
 
         # User-reported metrics
         if 'metrics' in observation:
             for k, v in observation['metrics'].items():
-                if k not in ['correct', 'f1', 'hits@k', 'bleu']:
+                if k not in ['correct', 'f1', 'hits@k', 'bleu', 'rouge-1',
+                             'rouge-2', 'rouge-L']:
                     if k in self.metrics_list:
                         with self._lock():
                             self.metrics[k] += v
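
For reference, a minimal standalone sketch of the py-rouge call pattern the new _rouge helper relies on (the evaluator settings mirror the diff above; the example strings and variable names are illustrative only):

import rouge  # pip install py-rouge

# Same evaluator settings as _rouge in parlai/core/metrics.py.
evaluator = rouge.Rouge(metrics=['rouge-n', 'rouge-l'], max_n=2)

guess = 'the cat sat on the mat'
answers = ['the cat sat on the mat', 'a dog slept on the rug']

# get_scores returns a dict per metric with 'f', 'p' and 'r' entries;
# as in the diff, only recall ('r') is kept, maximized over all answers.
scores = [evaluator.get_scores(guess, a) for a in answers]
rouge1 = max(s['rouge-1']['r'] for s in scores)
rouge2 = max(s['rouge-2']['r'] for s in scores)
rougel = max(s['rouge-l']['r'] for s in scores)
print(rouge1, rouge2, rougel)  # each should be 1.0 here, since the first answer matches the guess exactly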

requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -23,3 +23,4 @@ sphinx_rtd_theme
 tqdm
 websocket-client
 websocket-server
+py-rouge

tests/test_eval_model.py

Lines changed: 8 additions & 5 deletions

@@ -17,7 +17,7 @@ def test_output(self):
         """Test output of running eval_model"""
         parser = setup_args()
         parser.set_defaults(
-            task='tasks.repeat:RepeatTeacher:10',
+            task='integration_tests',
             model='repeat_label',
             datatype='valid',
             num_examples=5,
@@ -30,14 +30,17 @@ def test_output(self):
 
         # decode the output
         scores = str_output.split("\n---\n")
+
         for i in range(1, len(scores)):
             score = ast.literal_eval(scores[i])
             # check totals
-            self.assertTrue(score['exs'] == i,
-                            "Total is incorrect")
+            self.assertEqual(score['exs'], i, "Total is incorrect")
             # accuracy should be one
-            self.assertTrue(score['accuracy'] == 1,
-                            "accuracy != 1")
+            self.assertEqual(score['accuracy'], 1, "accuracy != 1")
+            if 'rouge-1' in score:
+                self.assertEqual(score['rouge-1'], 1, 'rouge1 != 1')
+                self.assertEqual(score['rouge-2'], 1, 'rouge-2 != 1')
+                self.assertEqual(score['rouge-L'], 1, 'rouge-L != 1')
 
 
 if __name__ == '__main__':
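
To exercise the updated test locally (a typical invocation, assuming the standard unittest entry point follows the guard above), something like the following works from the repository root:

import unittest

# Discover and run just this test module.
suite = unittest.defaultTestLoader.discover('tests', pattern='test_eval_model.py')
unittest.TextTestRunner(verbosity=2).run(suite)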
