Skip to content

Commit 241a8c3

Browse files
rhjohnstonepre-commit-ci[bot]rusty1s
authored
Mean reciprocal rank metric (#9632)
Addresses #9631. Implements `LinkPredMRR` as a `LinkPredMetric`, with test included. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: rusty1s <[email protected]>
1 parent 3f4f1a0 commit 241a8c3

File tree

4 files changed

+40
-2
lines changed

4 files changed

+40
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Added
99

10+
- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
1011
- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
1112
- Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554))
1213
- Added the `RemoveSelfLoops` transformation ([#9562](https://github.com/pyg-team/pytorch_geometric/pull/9562))

test/metrics/test_link_pred_metric.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from torch_geometric.metrics import (
77
LinkPredF1,
88
LinkPredMAP,
9+
LinkPredMRR,
910
LinkPredNDCG,
1011
LinkPredPrecision,
1112
LinkPredRecall,
@@ -98,3 +99,15 @@ def test_ndcg():
9899
result = metric.compute()
99100

100101
assert float(result) == pytest.approx(0.6934264)
102+
103+
104+
def test_mrr():
105+
pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])
106+
edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]])
107+
108+
metric = LinkPredMRR(k=2)
109+
assert str(metric) == 'LinkPredMRR(k=2)'
110+
metric.update(pred_mat, edge_label_index)
111+
result = metric.compute()
112+
113+
assert float(result) == pytest.approx((1 + 0.5 + 0) / 3)
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
# flake8: noqa
22

3-
from .link_pred import (LinkPredPrecision, LinkPredRecall, LinkPredF1,
4-
LinkPredMAP, LinkPredNDCG)
3+
from .link_pred import (
4+
LinkPredPrecision,
5+
LinkPredRecall,
6+
LinkPredF1,
7+
LinkPredMAP,
8+
LinkPredNDCG,
9+
LinkPredMRR,
10+
)
511

612
link_pred_metrics = [
713
'LinkPredPrecision',
814
'LinkPredRecall',
915
'LinkPredF1',
1016
'LinkPredMAP',
1117
'LinkPredNDCG',
18+
'LinkPredMRR',
1219
]
1320

1421
__all__ = link_pred_metrics

torch_geometric/metrics/link_pred.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,20 @@ def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
216216
out = dcg / idcg
217217
out[out.isnan() | out.isinf()] = 0.0
218218
return out
219+
220+
221+
class LinkPredMRR(LinkPredMetric):
222+
r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
223+
Reciprocal Rank).
224+
225+
Args:
226+
k (int): The number of top-:math:`k` predictions to evaluate against.
227+
"""
228+
higher_is_better: bool = True
229+
230+
def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
231+
rank = pred_isin_mat.type(torch.uint8).argmax(dim=-1)
232+
is_correct = pred_isin_mat.gather(1, rank.view(-1, 1)).view(-1)
233+
reciprocals = 1.0 / (rank + 1)
234+
reciprocals[~is_correct] = 0.0
235+
return reciprocals

0 commit comments

Comments
 (0)