Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
- Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554))
- Added the `RemoveSelfLoops` transformation ([#9562](https://github.com/pyg-team/pytorch_geometric/pull/9562))
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124`

| | `cpu` | `cu118` | `cu121` | `cu124` |
| ----------- | ----- | ------- | ------- | ------- |
| **Linux** | ✅ | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | | |
| **Linux** | ✅ | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | | |

#### PyTorch 2.3

Expand All @@ -399,9 +399,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `cu121` dependin

| | `cpu` | `cu118` | `cu121` |
| ----------- | ----- | ------- | ------- |
| **Linux** | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | |
| **Linux** | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | |

**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, and PyTorch 2.2.0/2.2.1/2.2.2 (following the same procedure).
**For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.
Expand Down
13 changes: 13 additions & 0 deletions test/metrics/test_link_pred_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch_geometric.metrics import (
LinkPredF1,
LinkPredMAP,
LinkPredMRR,
LinkPredNDCG,
LinkPredPrecision,
LinkPredRecall,
Expand Down Expand Up @@ -98,3 +99,15 @@ def test_ndcg():
result = metric.compute()

assert float(result) == pytest.approx(0.6934264)


def test_mrr():
pred_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])
edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]])

metric = LinkPredMRR(k=2)
assert str(metric) == 'LinkPredMRR(k=2)'
metric.update(pred_mat, edge_label_index)
result = metric.compute()

assert float(result) == pytest.approx((1 + 0.5 + 0) / 3)
11 changes: 9 additions & 2 deletions torch_geometric/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
# flake8: noqa

from .link_pred import (LinkPredPrecision, LinkPredRecall, LinkPredF1,
LinkPredMAP, LinkPredNDCG)
from .link_pred import (
LinkPredPrecision,
LinkPredRecall,
LinkPredF1,
LinkPredMAP,
LinkPredNDCG,
LinkPredMRR,
)

link_pred_metrics = [
'LinkPredPrecision',
'LinkPredRecall',
'LinkPredF1',
'LinkPredMAP',
'LinkPredNDCG',
'LinkPredMRR',
]

__all__ = link_pred_metrics
17 changes: 17 additions & 0 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,20 @@ def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
out = dcg / idcg
out[out.isnan() | out.isinf()] = 0.0
return out


class LinkPredMRR(LinkPredMetric):
r"""A link prediction metric to compute the MRR @ :math:`k` (Mean
Reciprocal Rank).

Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
"""
higher_is_better: bool = True

def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
rank = pred_isin_mat.type(torch.uint8).argmax(dim=-1)
is_correct = pred_isin_mat.gather(1, rank.view(-1, 1)).view(-1)
reciprocals = 1.0 / (rank + 1)
reciprocals[~is_correct] = 0.0
return reciprocals