From 8db0702d5242336f983b11a087fc96a474e09586 Mon Sep 17 00:00:00 2001
From: Matthias Fey <matthias.fey@tu-dortmund.de>
Date: Wed, 15 Jan 2025 05:38:45 +0100
Subject: [PATCH] More efficient way to compute weighted `IDCG` in
 `LinkPredNDCG` (#9946)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 test/metrics/test_link_pred_metric.py | 12 ++++---
 torch_geometric/metrics/link_pred.py  | 45 ++++++++++++++-------------
 2 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/test/metrics/test_link_pred_metric.py b/test/metrics/test_link_pred_metric.py
index 8e1964fc8b0a..d459c6d8e4b1 100644
--- a/test/metrics/test_link_pred_metric.py
+++ b/test/metrics/test_link_pred_metric.py
@@ -112,14 +112,13 @@ def test_map():
 
 def test_ndcg():
     pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])
-    edge_label_index = torch.tensor([[0, 0, 2, 2], [0, 1, 2, 1]])
-    edge_label_weight = torch.tensor([1.0, 2.0, 3.0, 0.5])
+    edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])
+    edge_label_weight = torch.tensor([1.0, 2.0, 0.1, 3.0, 0.5])
 
     metric = LinkPredNDCG(k=2)
     assert str(metric) == 'LinkPredNDCG(k=2)'
     metric.update(pred_index_mat, edge_label_index)
     result = metric.compute()
-
     assert float(result) == pytest.approx(0.6934264)
 
     # Test with `k > pred_index_mat.size(1)`:
@@ -134,9 +133,14 @@ def test_ndcg():
 
     metric.update(pred_index_mat, edge_label_index, edge_label_weight)
     result = metric.compute()
-
+    metric.reset()
     assert float(result) == pytest.approx(0.7854486)
 
+    perm = torch.randperm(edge_label_weight.size(0))
+    metric.update(pred_index_mat, edge_label_index[:, perm],
+                  edge_label_weight[perm])
+    assert metric.compute() == result
+
     # Test with `k > pred_index_mat.size(1)`:
     metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight)
     metric.compute()
diff --git a/torch_geometric/metrics/link_pred.py b/torch_geometric/metrics/link_pred.py
index fafd0c9d2f3f..7f30429162fb 100644
--- a/torch_geometric/metrics/link_pred.py
+++ b/torch_geometric/metrics/link_pred.py
@@ -4,7 +4,6 @@
 import torch
 from torch import Tensor
 
-from torch_geometric.index import index2ptr
 from torch_geometric.utils import cumsum, scatter
 
 try:
@@ -425,27 +424,29 @@ def _compute(self, data: LinkPredMetricData) -> Tensor:
             idcg = self.idcg[data.label_count.clamp(max=self.k)]
         else:
             assert data.edge_label_weight is not None
-            # Sort weights in buckets via two sorts:
-            weight, perm = data.edge_label_weight.sort(descending=True)
-            batch = data.edge_label_index[0][perm]
-            batch, perm = torch.sort(batch, stable=True)
-            weight = weight[perm]
-
-            # Shrink buckets that are larger than `k`:
-            arange = torch.arange(batch.size(0), device=batch.device)
-            ptr = index2ptr(batch, size=data.pred_index_mat.size(0))
-            batched_arange = arange - ptr[batch]
-            mask = batched_arange < self.k
-            batch = batch[mask]
-            batched_arange = batched_arange[mask]
-            weight = weight[mask]
-
-            # Compute ideal relevance matrix:
-            irel_mat = weight.new_zeros(data.pred_index_mat.size(0) * self.k)
-            irel_mat[batch * self.k + batched_arange] = weight
-            irel_mat = irel_mat.view(-1, self.k)
-
-            idcg = (irel_mat / self.discount.view(1, -1)).sum(dim=-1)
+            # Sort weights within example-wise buckets via two sorts to get the
+            # local index order within buckets:
+            weight, batch = data.edge_label_weight, data.edge_label_index[0]
+            perm1 = weight.argsort(descending=True)
+            perm2 = batch[perm1].argsort(stable=True)
+            global_index = torch.empty_like(perm1)
+            global_index[perm1[perm2]] = torch.arange(
+                global_index.size(0), device=global_index.device)
+            local_index = global_index - cumsum(data.label_count)[batch]
+
+            # Get the discount per local index:
+            discount = torch.cat([
+                self.discount,
+                self.discount.new_full((1, ), fill_value=float('inf')),
+            ])
+            discount = discount[local_index.clamp(max=self.k + 1)]
+
+            idcg = scatter(  # Apply discount and aggregate:
+                weight / discount,
+                batch,
+                dim_size=data.pred_index_mat.size(0),
+                reduce='sum',
+            )
 
         out = dcg / idcg
         out[out.isnan() | out.isinf()] = 0.0