From 8db0702d5242336f983b11a087fc96a474e09586 Mon Sep 17 00:00:00 2001 From: Matthias Fey <matthias.fey@tu-dortmund.de> Date: Wed, 15 Jan 2025 05:38:45 +0100 Subject: [PATCH] More efficient way to compute weighted `IDCG` in `LinkPredNDCG` (#9946) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- test/metrics/test_link_pred_metric.py | 12 ++++--- torch_geometric/metrics/link_pred.py | 45 ++++++++++++++------------- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py index 8e1964fc8b0a..d459c6d8e4b1 100644 --- a/test/metrics/test_link_pred_metric.py +++ b/test/metrics/test_link_pred_metric.py @@ -112,14 +112,13 @@ def test_map(): def test_ndcg(): pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]]) - edge_label_index = torch.tensor([[0, 0, 2, 2], [0, 1, 2, 1]]) - edge_label_weight = torch.tensor([1.0, 2.0, 3.0, 0.5]) + edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) + edge_label_weight = torch.tensor([1.0, 2.0, 0.1, 3.0, 0.5]) metric = LinkPredNDCG(k=2) assert str(metric) == 'LinkPredNDCG(k=2)' metric.update(pred_index_mat, edge_label_index) result = metric.compute() - assert float(result) == pytest.approx(0.6934264) # Test with `k > pred_index_mat.size(1)`: @@ -134,9 +133,14 @@ def test_ndcg(): metric.update(pred_index_mat, edge_label_index, edge_label_weight) result = metric.compute() - + metric.reset() assert float(result) == pytest.approx(0.7854486) + perm = torch.randperm(edge_label_weight.size(0)) + metric.update(pred_index_mat, edge_label_index[:, perm], + edge_label_weight[perm]) + assert metric.compute() == result + # Test with `k > pred_index_mat.size(1)`: metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight) metric.compute() diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py index fafd0c9d2f3f..7f30429162fb 100644 --- a/torch_geometric/metrics/link_pred.py +++ b/torch_geometric/metrics/link_pred.py @@ -4,7 +4,6 @@ import torch from torch import Tensor -from torch_geometric.index import index2ptr from torch_geometric.utils import cumsum, scatter try: @@ -425,27 +424,29 @@ def _compute(self, data: LinkPredMetricData) -> Tensor: idcg = self.idcg[data.label_count.clamp(max=self.k)] else: assert data.edge_label_weight is not None - # Sort weights in buckets via two sorts: - weight, perm = data.edge_label_weight.sort(descending=True) - batch = data.edge_label_index[0][perm] - batch, perm = torch.sort(batch, stable=True) - weight = weight[perm] - - # Shrink buckets that are larger than `k`: - arange = torch.arange(batch.size(0), device=batch.device) - ptr = index2ptr(batch, size=data.pred_index_mat.size(0)) - batched_arange = arange - ptr[batch] - mask = batched_arange < self.k - batch = batch[mask] - batched_arange = batched_arange[mask] - weight = weight[mask] - - # Compute ideal relevance matrix: - irel_mat = weight.new_zeros(data.pred_index_mat.size(0) * self.k) - irel_mat[batch * self.k + batched_arange] = weight - irel_mat = irel_mat.view(-1, self.k) - - idcg = (irel_mat / self.discount.view(1, -1)).sum(dim=-1) + # Sort weights within example-wise buckets via two sorts to get the + # local index order within buckets: + weight, batch = data.edge_label_weight, data.edge_label_index[0] + perm1 = weight.argsort(descending=True) + perm2 = batch[perm1].argsort(stable=True) + global_index = torch.empty_like(perm1) + global_index[perm1[perm2]] = torch.arange( + global_index.size(0), device=global_index.device) + local_index = global_index - cumsum(data.label_count)[batch] + + # Get the discount per local index: + discount = torch.cat([ + self.discount, + self.discount.new_full((1, ), fill_value=float('inf')), + ]) + discount = discount[local_index.clamp(max=self.k + 1)] + + idcg = scatter( # Apply discount and aggregate: + weight / discount, + batch, + dim_size=data.pred_index_mat.size(0), + reduce='sum', + ) out = dcg / idcg out[out.isnan() | out.isinf()] = 0.0